From 3abc12cd5cadc10c542c8446d72e41ff62a06bb3 Mon Sep 17 00:00:00 2001 From: Thomas de Jong Date: Wed, 15 Apr 2026 21:23:57 +0200 Subject: [PATCH 1/5] feat!: refactor to config-based API with improved parsing and trust chain --- CHANGELOG.md | 34 +- Justfile | 4 +- README.md | 568 ++++----- benchmark_test.go | 321 ++--- call_options.go | 107 -- chain_analysis.go | 233 ---- chain_validation.go | 57 - classify.go | 60 + classify_test.go | 37 + config.go | 371 +++--- config_test.go | 675 +++-------- config_validation.go | 90 -- doc.go | 174 +-- events.go | 14 - example_test.go | 260 ++-- extract_from_test.go | 419 ------- extractor.go | 191 +-- extractor_test.go | 1068 ++++++----------- forwarded_test.go | 206 ---- input.go | 51 + request_input_test.go => input_test.go | 137 +-- ip_parse_test.go | 385 ------ ip_validation.go | 88 -- logger.go | 25 - logger_test.go | 410 ------- metrics.go | 28 - metrics_test.go | 503 -------- observability.go | 72 ++ observability_test.go | 786 ++++++++++++ options.go | 203 ---- parse_benchmark_test.go | 18 + chain_capacity.go => parse_chain_capacity.go | 5 +- parse_errors.go | 12 + forwarded.go => parse_forwarded.go | 52 +- parse_forwarded_test.go | 125 ++ parse_fuzz_test.go | 117 ++ ip_parse.go => parse_ip.go | 76 +- parse_ip_test.go | 174 +++ parse_remote_addr.go | 21 + parse_remote_addr_test.go | 45 + parse_xff.go | 54 + parse_xff_test.go | 68 ++ parser_fuzz_test.go | 168 --- presets.go | 56 +- presets_test.go | 71 +- request_input.go | 153 --- resolver.go | 315 +++++ resolver_test.go | 590 +++++++++ sources.go => source.go | 215 ++-- source_build_test.go | 69 ++ source_chain.go | 121 -- source_chain_extract.go | 106 ++ source_chain_extract_test.go | 401 +++++++ source_chain_test.go | 431 ------- source_chained.go | 56 + source_chained_test.go | 246 ++++ source_compile.go | 45 + source_execution.go | 439 +++++++ source_failure.go | 32 + source_forwarded.go | 103 -- source_forwarded_test.go | 205 ---- source_helpers.go | 19 - source_remote_addr_extract.go | 27 + source_remote_addr_extract_test.go | 139 +++ source_remote_addr_test.go | 81 -- source_request.go | 104 ++ source_single_header.go | 157 +-- source_single_header_test.go | 320 ++--- test_helpers_test.go | 22 +- trust_benchmark_test.go | 103 ++ trust_chain.go | 170 +++ trust_chain_test.go | 184 +++ trust_client_ip.go | 145 +++ trust_client_ip_test.go | 141 +++ trusted_proxy_matcher.go => trust_matcher.go | 48 +- ...y_matcher_test.go => trust_matcher_test.go | 42 +- types.go | 5 - types_test.go | 34 +- xff_parse.go | 38 - xff_test.go | 771 ------------ 80 files changed, 6798 insertions(+), 7918 deletions(-) delete mode 100644 call_options.go delete mode 100644 chain_analysis.go delete mode 100644 chain_validation.go create mode 100644 classify.go create mode 100644 classify_test.go delete mode 100644 config_validation.go delete mode 100644 events.go delete mode 100644 extract_from_test.go delete mode 100644 forwarded_test.go create mode 100644 input.go rename request_input_test.go => input_test.go (58%) delete mode 100644 ip_parse_test.go delete mode 100644 ip_validation.go delete mode 100644 logger.go delete mode 100644 logger_test.go delete mode 100644 metrics.go delete mode 100644 metrics_test.go create mode 100644 observability.go create mode 100644 observability_test.go delete mode 100644 options.go create mode 100644 parse_benchmark_test.go rename chain_capacity.go => parse_chain_capacity.go (85%) create mode 100644 parse_errors.go rename forwarded.go => parse_forwarded.go (65%) create mode 100644 parse_forwarded_test.go create mode 100644 parse_fuzz_test.go rename ip_parse.go => parse_ip.go (69%) create mode 100644 parse_ip_test.go create mode 100644 parse_remote_addr.go create mode 100644 parse_remote_addr_test.go create mode 100644 parse_xff.go create mode 100644 parse_xff_test.go delete mode 100644 parser_fuzz_test.go delete mode 100644 request_input.go create mode 100644 resolver.go create mode 100644 resolver_test.go rename sources.go => source.go (51%) create mode 100644 source_build_test.go delete mode 100644 source_chain.go create mode 100644 source_chain_extract.go create mode 100644 source_chain_extract_test.go delete mode 100644 source_chain_test.go create mode 100644 source_chained.go create mode 100644 source_chained_test.go create mode 100644 source_compile.go create mode 100644 source_execution.go create mode 100644 source_failure.go delete mode 100644 source_forwarded.go delete mode 100644 source_forwarded_test.go delete mode 100644 source_helpers.go create mode 100644 source_remote_addr_extract.go create mode 100644 source_remote_addr_extract_test.go delete mode 100644 source_remote_addr_test.go create mode 100644 source_request.go create mode 100644 trust_benchmark_test.go create mode 100644 trust_chain.go create mode 100644 trust_chain_test.go create mode 100644 trust_client_ip.go create mode 100644 trust_client_ip_test.go rename trusted_proxy_matcher.go => trust_matcher.go (90%) rename trusted_proxy_matcher_test.go => trust_matcher_test.go (54%) delete mode 100644 xff_parse.go delete mode 100644 xff_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e9f16a..50db953 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,27 +8,31 @@ The format is based on Keep a Changelog and this project follows Semantic Versio ### Added -- Parser fuzz targets for `parseIP`, `parseRemoteAddr`, `parseXFFValues`, and `parseForwardedValues`, plus new `just fuzz` and `just fuzz-one` maintainer commands. -- Expanded regression coverage for extraction behavior, request-input header adaptation, logger/metrics reporting, typed error formatting, and Prometheus adapter examples/tests. +- `Resolver`, `ResolverConfig`, `PreferredFallback`, and `Resolution` as the request-scoped API for strict and preferred client IP resolution. +- `StrictResolutionFromContext` and `PreferredResolutionFromContext` for reusing cached resolver state across middleware. +- `Input`, `ExtractInput`, and `ExtractInputAddr` for framework-agnostic request handling. +- `ParseRemoteAddr` helper. +- `ClassifyError`, `ResultKind`, and result classification constants for coarse-grained policy handling. +- Exported `SecurityEvent...` constants and public `SourceStaticFallback`. +- Updated docs, examples, presets, and Prometheus examples around the resolver-first architecture. ### Changed -- Option naming is now consistently `With...` for policy options (`WithTrustedProxyPrefixes`, `WithMinTrustedProxies`, `WithMaxTrustedProxies`, `WithAllowPrivateIPs`, `WithAllowedReservedClientPrefixes`, `WithMaxChainLength`, and related trust helpers). -- Internal chain extraction logic from `xff.go` is split into focused files (`chain_capacity.go`, `chain_validation.go`, `xff_parse.go`, `chain_analysis.go`) with no behavior change. -- No-op call options now reuse the existing config and source chain when policy is unchanged. -- Typed-nil `RequestInput.Headers` providers are treated as absent instead of being invoked. -- The optional Prometheus adapter module now depends on `github.com/abczzz13/clientip v0.0.6`. -- Typed source API with an opaque `Source` type, `HeaderSource(string)`, fully typed `Extraction.Source` / `ExtractionError.Source`, and `WithSourcePriority(...Source)`. -- Per-call policy API via `CallOption` and helpers such as `WithCallSecurityMode`, `WithCallSourcePriority`, and `WithCallTrustedProxyPrefixes`. -- **BREAKING:** `Extractor.Extract`, `Extractor.ExtractAddr`, `Extractor.ExtractFrom`, and `Extractor.ExtractAddrFrom` now accept `...CallOption` (instead of `...OverrideOptions`). -- **BREAKING:** `Extractor.Extract(nil)` and `Extractor.ExtractAddr(nil)` now return `ErrNilRequest`. -- **BREAKING:** `NormalizeSourceName` has been removed; use `HeaderSource(name).String()` when you need the canonical identifier for an arbitrary header name. -- **BREAKING:** custom `Source` text/JSON encoding now uses canonical MIME header names (for example `Cf-Connecting-Ip`) instead of underscore-normalized identifiers such as `cf_connecting_ip`; `Source.String()` still returns the underscore-normalized identifier. +- **BREAKING:** `Resolver` is now the primary documented API. `Extractor` remains as the strict low-level primitive. +- **BREAKING:** `RequestInput` is renamed to `Input`, `ExtractFrom` is renamed to `ExtractInput`, and `ExtractAddrFrom` is renamed to `ExtractInputAddr`. +- **BREAKING:** `Overrides`, `ExtractWith`, `ExtractAddrWith`, `ExtractFromWith`, and `ExtractAddrFromWith` are removed. +- **BREAKING:** `SecurityMode` is removed. Preferred behavior now lives on `Resolver.ResolvePreferred`. +- **BREAKING:** `ResolverConfig` now uses explicit `PreferredFallback` selection instead of competing fallback knobs. +- **BREAKING:** Preferred fallback is explicit resolver behavior with `Resolution.FallbackUsed`; fallback does not emit separate metrics or log events in this phase. +- **BREAKING:** `SourceStaticFallback` remains public but is resolver-result-only; it cannot be used in `Config.Sources`. +- Presets remain `Config` helpers and now document resolver-oriented usage more clearly. +- Prometheus integration is constructor-based: build metrics with `prometheus.New()` or `prometheus.NewWithRegisterer(...)` and assign them through `Config.Metrics`. +- Internal orchestration now sits behind `internal/engine` and concrete source execution behind `internal/source`. ### Removed -- Per-call `OverrideOptions` and `Set(...)` in favor of `CallOption`. -- One-shot helpers `ExtractWithOptions`, `ExtractAddrWithOptions`, `ExtractFromWithOptions`, and `ExtractAddrFromWithOptions`. +- Per-call override APIs and the old security-mode split. +- The older extraction naming built around `RequestInput` and `ExtractFrom`. ## [0.0.6] - 2026-02-18 diff --git a/Justfile b/Justfile index c9e9bae..4fc4207 100644 --- a/Justfile +++ b/Justfile @@ -42,7 +42,9 @@ bench-all *args: bench-save name pattern="." count="6" *args: @mkdir -p .bench - @bash -eo pipefail -c 'outfile="$1"; echo "Saving benchmark sample to $outfile"; go test -run "^$" -bench "{{pattern}}" -benchmem -count={{count}} ./... {{args}} | tee "$outfile"; GOWORK={{adapter_gowork}} go -C prometheus test -run "^$" -bench "{{pattern}}" -benchmem -count={{count}} ./... {{args}} | tee -a "$outfile"' _ ".bench/{{name}}.txt" + @echo "Saving benchmark sample to .bench/{{name}}.txt" + @go test -run "^$" -bench "{{pattern}}" -benchmem -count={{count}} ./... {{args}} > ".bench/{{name}}.txt" + @GOWORK={{adapter_gowork}} go -C prometheus test -run "^$" -bench "{{pattern}}" -benchmem -count={{count}} ./... {{args}} >> ".bench/{{name}}.txt" bench-compare-saved before after: @just bench-compare ".bench/{{before}}.txt" ".bench/{{after}}.txt" diff --git a/README.md b/README.md index 9090cea..2c8e4e9 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/abczzz13/clientip.svg)](https://pkg.go.dev/github.com/abczzz13/clientip) [![License](https://img.shields.io/github/license/abczzz13/clientip)](LICENSE) -Secure client IP extraction for `net/http` and framework-agnostic request inputs with trusted proxy validation, configurable source priority, and optional logging/metrics. +Secure client IP extraction for `net/http` and framework-agnostic request inputs with trusted proxy validation, explicit source modeling, and request-scoped resolver caching. ## Stability This project is pre-`v1.0.0` and still before `v0.1.0`, so public APIs may change as the package evolves. -Any breaking changes will be called out in `CHANGELOG.md`. +Any breaking changes are called out in `CHANGELOG.md`. ## Install @@ -23,114 +23,112 @@ Optional Prometheus adapter: go get github.com/abczzz13/clientip/prometheus ``` -```go -import "github.com/abczzz13/clientip" -``` - -## Compatibility +## Choose the API -- Core module (`github.com/abczzz13/clientip`) supports Go `1.21+`. -- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) has a minimum Go version of `1.21`; CI currently validates consumer mode on Go `1.21.x` and `1.26.x`. -- Prometheus client dependency in the adapter is pinned to `github.com/prometheus/client_golang v1.21.1`. +- `Resolver` is the primary API. Use it when middleware, handlers, or framework adapters need to resolve the client IP once and reuse the result on the same request. +- `Extractor` is the low-level strict primitive. Use it when you only need one extraction call and do not need request-scoped caching or preferred fallback. +- `Input` is the framework-agnostic carrier for non-`net/http` integrations. +- `ParseRemoteAddr` and `ClassifyError` are small helpers for explicit fallback and policy code. -## Quick start +Construct an `Extractor` once and reuse it. Build a `Resolver` on top when you want strict or preferred request-scoped resolution. -By default, `New()` extracts from `RemoteAddr` only. +## Quick Start -### Presets (recommended) +Use `Resolver.ResolveStrict` for security-sensitive or audit-oriented decisions. -Use these when you want setup by deployment type instead of low-level options: +```go +extractor, err := clientip.New(clientip.PresetLoopbackReverseProxy()) +if err != nil { + log.Fatal(err) +} -- `PresetDirectConnection()` app receives traffic directly (no trusted proxy headers) -- `PresetLoopbackReverseProxy()` reverse proxy on same host (`127.0.0.1` / `::1`) -- `PresetVMReverseProxy()` typical VM/private-network reverse proxy setup -- `PresetPreferredHeaderThenXFFLax("X-Frontend-IP")` prefer custom header, then `X-Forwarded-For`, then `RemoteAddr` (lax fallback) +resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{}) +if err != nil { + log.Fatal(err) +} -#### Which preset should I use? +req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} +req.Header.Set("X-Forwarded-For", "8.8.8.8") -| If your setup looks like... | Start with... | -| --- | --- | -| App is directly internet-facing (no reverse proxy) | `PresetDirectConnection()` | -| NGINX/Caddy runs on the same host and proxies to your app | `PresetLoopbackReverseProxy()` | -| App runs on a VM/private network behind one or more internal proxies | `PresetVMReverseProxy()` | -| You have a best-effort custom header and want fallback to XFF | `PresetPreferredHeaderThenXFFLax("X-Frontend-IP")` | +req, resolution := resolver.ResolveStrict(req) +if resolution.Err != nil { + log.Fatal(resolution.Err) +} -Preset examples: +fmt.Printf("Client IP: %s from %s\n", resolution.IP, resolution.Source) -```go -// Typical VM setup (reverse proxy + private networking) -vmExtractor, err := clientip.New( - clientip.PresetVMReverseProxy(), -) - -// Prefer a best-effort header, then fallback to XFF and RemoteAddr -fallbackExtractor, err := clientip.New( - clientip.WithTrustedLoopbackProxy(), - clientip.PresetPreferredHeaderThenXFFLax("X-Frontend-IP"), -) - -_ = vmExtractor -_ = fallbackExtractor +if cached, ok := clientip.StrictResolutionFromContext(req.Context()); ok { + fmt.Printf("Cached: %s\n", cached.IP) +} ``` -### Simple (no proxy configuration) +## Preferred Resolution And Fallback + +Use `Resolver.ResolvePreferred` when best-effort client IPs are operationally useful, such as rate limiting, analytics, or request tracing. ```go -extractor, err := clientip.New() +extractor, err := clientip.New(clientip.Config{ + TrustedProxyPrefixes: clientip.LoopbackProxyPrefixes(), + Sources: []clientip.Source{clientip.SourceXForwardedFor}, +}) if err != nil { log.Fatal(err) } -ip, err := extractor.ExtractAddr(req) +resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{ + PreferredFallback: clientip.PreferredFallbackRemoteAddr, +}) if err != nil { - fmt.Printf("Failed: %v\n", err) - return + log.Fatal(err) +} + +req := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: make(http.Header)} + +_, resolution := resolver.ResolvePreferred(req) +if resolution.Err != nil { + log.Fatal(resolution.Err) } -fmt.Printf("Client IP: %s\n", ip) +fmt.Printf("Client IP: %s from %s (fallback=%t)\n", resolution.IP, resolution.Source, resolution.FallbackUsed) ``` -### Framework-friendly input API +Important fallback guidance: + +- Preferred fallback is explicit and only lives on `Resolver`. +- `PreferredFallbackRemoteAddr` is operationally useful, but it is not equivalent to validated proxy-header extraction. +- Preferred resolution is not suitable for authorization, ACLs, or other trust-boundary decisions. +- Fallback observability is result-only in this phase. Inspect `Resolution.FallbackUsed`; do not expect a separate fallback metric or log event. + +If you want a synthetic fallback value instead of `RemoteAddr`, set `ResolverConfig{PreferredFallback: clientip.PreferredFallbackStaticIP, StaticFallbackIP: ...}`. Successful static fallback reports `SourceStaticFallback`. -Use `ExtractFrom` when your framework does not expose `*http.Request` directly. +## Framework-Agnostic Input + +Use `Input` with either `Extractor` or `Resolver` when your framework does not expose `*http.Request` directly. ```go -input := clientip.RequestInput{ +input := clientip.Input{ Context: ctx, RemoteAddr: remoteAddr, Path: path, - Headers: headersProvider, // any type implementing Values(name string) []string + Headers: headersProvider, } -extraction, err := extractor.ExtractFrom(input) -if err != nil { +input, resolution := resolver.ResolveInputStrict(input) +if resolution.Err != nil { // handle error } -``` - -`http.Header` already implements the required header interface, so for `net/http` -style frameworks (Gin, Echo, Chi) you can keep using `Extract(req)` directly. -`ExtractFrom` only requests header names required by the configured -`WithSourcePriority(...)` sources. - -Call-time overrides apply to `ExtractFrom(...)` as well, so middleware can -switch to a different trusted header or fall back to `SourceRemoteAddr` -without constructing a second extractor. - -```go -// Gin -extraction, err := extractor.Extract(c.Request) - -// Echo -extraction, err := extractor.Extract(c.Request()) +if cached, ok := clientip.StrictResolutionFromContext(input.Context); ok { + _ = cached +} ``` -For `fasthttp`/Fiber style frameworks, provide a header adapter with -`HeaderValuesFunc` and preserve duplicate header lines: +`Input.Headers` must preserve repeated header lines as separate slice entries. Do not merge duplicate lines into a single comma-joined string. + +For `fasthttp`/Fiber style integrations: ```go -input := clientip.RequestInput{ +input := clientip.Input{ Context: c.UserContext(), RemoteAddr: c.Context().RemoteAddr().String(), Path: c.Path(), @@ -149,376 +147,224 @@ input := clientip.RequestInput{ } ``` -Important: do not merge repeated header lines into a single comma-joined value. -Single-IP sources (for example `X-Real-IP` or custom headers) rely on per-line -values to detect duplicates in strict mode. +## Presets + +Presets return a flat `clientip.Config` that you can pass directly to `New` or tweak before construction. -### Behind reverse proxies +- `PresetDirectConnection()` uses `RemoteAddr` only. +- `PresetLoopbackReverseProxy()` trusts loopback proxies and prioritizes `X-Forwarded-For` before `RemoteAddr`. +- `PresetVMReverseProxy()` trusts loopback and common private-network proxy ranges and prioritizes `X-Forwarded-For` before `RemoteAddr`. ```go -cidrs, err := clientip.ParseCIDRs("10.0.0.0/8", "172.16.0.0/12") +extractor, err := clientip.New(clientip.PresetVMReverseProxy()) if err != nil { log.Fatal(err) } +``` + +If you need to tweak a preset, modify the returned config before calling `New`: + +```go +cfg := clientip.PresetVMReverseProxy() +cfg.Sources = []clientip.Source{clientip.SourceForwarded, clientip.SourceRemoteAddr} -extractor, err := clientip.New( - // min=0 allows requests where proxy headers contain only the client IP - // (trusted RemoteAddr is validated separately). - clientip.WithTrustedProxyPrefixes(cidrs...), - clientip.WithMinTrustedProxies(0), - clientip.WithMaxTrustedProxies(3), - clientip.WithSourcePriority(clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), - clientip.WithChainSelection(clientip.RightmostUntrustedIP), -) +extractor, err := clientip.New(cfg) if err != nil { log.Fatal(err) } ``` -### Custom header priority +Presets configure `Config`, not `ResolverConfig`. Preferred resolver fallback stays an explicit resolver-level choice. -```go -extractor, err := clientip.New( - clientip.WithTrustedPrivateProxyRanges(), - clientip.WithSourcePriority( - clientip.HeaderSource("CF-Connecting-IP"), - clientip.SourceXForwardedFor, - clientip.SourceRemoteAddr, - ), -) -``` +## Config -### Security mode (strict vs lax) +`Config` stays flat in the current API. -```go -// Strict is default and fails closed on security errors -// (including malformed Forwarded and invalid present source values). -strictExtractor, _ := clientip.New( - clientip.WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - clientip.WithSourcePriority(clientip.HeaderSource("X-Frontend-IP"), clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), - clientip.WithSecurityMode(clientip.SecurityModeStrict), -) - -// Lax mode allows fallback to lower-priority sources after those errors. -laxExtractor, _ := clientip.New( - clientip.WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - clientip.WithSourcePriority(clientip.HeaderSource("X-Frontend-IP"), clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), - clientip.WithSecurityMode(clientip.SecurityModeLax), -) -``` +Important fields: -### Logging (bring your own) +- `TrustedProxyPrefixes []netip.Prefix` +- `MinTrustedProxies int` +- `MaxTrustedProxies int` +- `AllowPrivateIPs bool` +- `AllowedReservedClientPrefixes []netip.Prefix` +- `MaxChainLength int` +- `ChainSelection ChainSelection` +- `DebugInfo bool` +- `Sources []Source` +- `Logger Logger` +- `Metrics Metrics` -By default, logging is disabled. Use `WithLogger` to opt in. +Useful helpers: -`WithLogger` accepts any implementation of: +- `DefaultConfig()` +- `ParseCIDRs(...string)` +- `LoopbackProxyPrefixes()` +- `PrivateProxyPrefixes()` +- `LocalProxyPrefixes()` +- `ProxyPrefixesFromAddrs(...netip.Addr)` -```go -type Logger interface { - WarnContext(context.Context, string, ...any) -} -``` +Built-in extractor sources: -This intentionally mirrors `slog.Logger.WarnContext`, so `*slog.Logger` -works directly with `WithLogger` (no adapter needed). +- `SourceForwarded` +- `SourceXForwardedFor` +- `SourceXRealIP` +- `SourceRemoteAddr` +- `HeaderSource(name)` for custom headers -The context passed to logger calls comes from `req.Context()` (`Extract`) or -`RequestInput.Context` (`ExtractFrom`), so trace/span IDs added by middleware -remain available in logs. +Resolver-only result source: -Structured log attributes are passed as alternating key/value pairs, matching -the style used by `slog`. +- `SourceStaticFallback` -When configured, the extractor emits warning logs for security-significant -conditions such as `multiple_headers`, `malformed_forwarded`, `chain_too_long`, -`untrusted_proxy`, `no_trusted_proxies`, `too_few_trusted_proxies`, and `too_many_trusted_proxies`. +`Extractor` walks `Config.Sources` in order. `ErrSourceUnavailable` allows the next source to run, while security-significant failures remain terminal. -```go -extractor, err := clientip.New( - clientip.WithLogger(slog.Default()), -) -``` +## Low-Level Extraction -For loggers without context-aware APIs, adapters can simply ignore `ctx`: +Use `Extractor` directly when you want strict extraction without request-scoped caching or preferred fallback. ```go -type stdLoggerAdapter struct{ l *log.Logger } +extractor, err := clientip.New(clientip.DefaultConfig()) +if err != nil { + log.Fatal(err) +} -func (a stdLoggerAdapter) WarnContext(_ context.Context, msg string, args ...any) { - a.l.Printf("WARN %s %v", msg, args) +extraction, err := extractor.Extract(req) +if err != nil { + log.Fatal(err) } -extractor, err := clientip.New( - clientip.WithLogger(stdLoggerAdapter{l: log.Default()}), -) +fmt.Printf("Client IP: %s from %s\n", extraction.IP, extraction.Source) ``` -Tiny adapters for other popular loggers: +Framework-agnostic extraction is also available: ```go -type zapAdapter struct{ l *zap.SugaredLogger } - -func (a zapAdapter) WarnContext(_ context.Context, msg string, args ...any) { - a.l.With(args...).Warn(msg) +extraction, err := extractor.ExtractInput(input) +if err != nil { + log.Fatal(err) } ``` +## Errors + +Typed errors remain the detailed error surface: + ```go -type logrusAdapter struct{ l *logrus.Logger } - -func (a logrusAdapter) WarnContext(_ context.Context, msg string, args ...any) { - fields := logrus.Fields{} - for i := 0; i+1 < len(args); i += 2 { - key, ok := args[i].(string) - if !ok { - continue - } - fields[key] = args[i+1] +_, resolution := resolver.ResolveStrict(req) +if resolution.Err != nil { + switch { + case errors.Is(resolution.Err, clientip.ErrMultipleSingleIPHeaders): + case errors.Is(resolution.Err, clientip.ErrInvalidForwardedHeader): + case errors.Is(resolution.Err, clientip.ErrUntrustedProxy): + case errors.Is(resolution.Err, clientip.ErrNoTrustedProxies): + case errors.Is(resolution.Err, clientip.ErrTooFewTrustedProxies): + case errors.Is(resolution.Err, clientip.ErrTooManyTrustedProxies): + case errors.Is(resolution.Err, clientip.ErrInvalidIP): + case errors.Is(resolution.Err, clientip.ErrSourceUnavailable): + case errors.Is(resolution.Err, clientip.ErrNilRequest): } - a.l.WithFields(fields).Warn(msg) } ``` +`ClassifyError` provides a smaller policy-oriented layer on top of those typed errors: + ```go -type zerologAdapter struct{ l zerolog.Logger } - -func (a zerologAdapter) WarnContext(_ context.Context, msg string, args ...any) { - event := a.l.Warn() - for i := 0; i+1 < len(args); i += 2 { - key, ok := args[i].(string) - if !ok { - continue - } - event = event.Interface(key, args[i+1]) - } - event.Msg(msg) +switch clientip.ClassifyError(resolution.Err) { +case clientip.ResultSuccess: +case clientip.ResultUnavailable: +case clientip.ResultInvalid: +case clientip.ResultUntrusted: +case clientip.ResultMalformed: +case clientip.ResultCanceled: +case clientip.ResultUnknown: } ``` -If your stack stores trace metadata in `context.Context`, enrich the adapter by -extracting that value and appending it to `args`. +`ResultUnknown` covers non-nil errors outside the package's standard extraction and resolution categories. -### Prometheus metrics (simple setup) +Typed chain-related errors expose additional context: -```go -import clientipprom "github.com/abczzz13/clientip/prometheus" +- `ProxyValidationError`: `Chain`, `TrustedProxyCount`, `MinTrustedProxies`, `MaxTrustedProxies` +- `InvalidIPError`: `Chain`, `ExtractedIP`, `Index`, `TrustedProxies` +- `RemoteAddrError`: `RemoteAddr` +- `ChainTooLongError`: `ChainLength`, `MaxLength` -extractor, err := clientip.New( - clientipprom.WithMetrics(), -) -``` +## Logging -### Prometheus metrics (custom registerer) +Logging is disabled by default. Set `Config.Logger` to opt in. ```go -import ( - clientipprom "github.com/abczzz13/clientip/prometheus" - "github.com/prometheus/client_golang/prometheus" -) - -registry := prometheus.NewRegistry() - -extractor, err := clientip.New( - clientipprom.WithRegisterer(registry), -) +extractor, err := clientip.New(clientip.Config{ + Logger: slog.Default(), +}) ``` -You can also construct metrics explicitly with `clientipprom.New()` or -`clientipprom.NewWithRegisterer(...)` and pass them via -`clientip.WithMetrics(...)`. - -## Options - -`New(opts...)` accepts one or more `Option` builders. +The logger interface intentionally matches `slog.Logger.WarnContext`: -Construct an extractor once and reuse it for all requests. - -- `WithTrustedProxyPrefixes(...netip.Prefix)` add trusted proxy network prefixes -- `WithTrustedLoopbackProxy()` trust loopback upstream proxies (`127.0.0.0/8`, `::1/128`) -- `WithTrustedPrivateProxyRanges()` trust private upstream proxy ranges (`10/8`, `172.16/12`, `192.168/16`, `fc00::/7`) -- `WithTrustedLocalProxyDefaults()` trust loopback + private proxy ranges -- `WithTrustedProxyAddrs(...netip.Addr)` add trusted upstream proxy host addresses -- `PresetDirectConnection()` remote-address only extraction preset -- `PresetLoopbackReverseProxy()` loopback reverse-proxy preset (`X-Forwarded-For`, then `RemoteAddr`) -- `PresetVMReverseProxy()` VM/private-network reverse-proxy preset (`X-Forwarded-For`, then `RemoteAddr`) -- `PresetPreferredHeaderThenXFFLax(string)` preferred-header fallback preset in lax mode -- `WithMinTrustedProxies(int)` / `WithMaxTrustedProxies(int)` set trusted-proxy count bounds for chain headers -- `WithAllowPrivateIPs(bool)` allow private client IPs -- `WithAllowedReservedClientPrefixes(...netip.Prefix)` explicitly allow selected reserved/special-use client ranges -- `ParseCIDRs(...string)` parse CIDR strings to `[]netip.Prefix` for typed options -- `WithMaxChainLength(int)` limit proxy chain length from `Forwarded`/`X-Forwarded-For` (default 100) -- `WithChainSelection(ChainSelection)` choose `RightmostUntrustedIP` (default) or `LeftmostUntrustedIP` -- `WithSourcePriority(...Source)` set source order; built-ins: `SourceForwarded`, `SourceXForwardedFor`, `SourceXRealIP`, `SourceRemoteAddr`; use `HeaderSource("CF-Connecting-IP")` for custom headers -- `WithSecurityMode(SecurityMode)` choose `SecurityModeStrict` (default) or `SecurityModeLax` -- `WithLogger(Logger)` inject logger implementation -- `WithMetrics(Metrics)` inject custom metrics implementation directly -- `WithMetricsFactory(func() (Metrics, error))` lazily construct metrics after option validation (last metrics option wins) -- `WithDebugInfo(bool)` include chain analysis in `Extraction.DebugInfo` - -Default source order is `SourceRemoteAddr`. - -Any header-based source requires trusted upstream proxy ranges (`WithTrustedProxyPrefixes` or one of the trust helpers). - -Prometheus adapter helpers from `github.com/abczzz13/clientip/prometheus`: +```go +type Logger interface { + WarnContext(context.Context, string, ...any) +} +``` -- `WithMetrics()` install Prometheus metrics on default registerer -- `WithRegisterer(prometheus.Registerer)` install Prometheus metrics on custom registerer -- `New()` / `NewWithRegisterer(prometheus.Registerer)` for explicit metrics construction +The context passed to logger calls comes from `req.Context()` (`Extract`) or `Input.Context` (`ExtractInput`). -Proxy count bounds (`min`/`max`) apply to trusted proxies present in `Forwarded` (from `for=` values) and `X-Forwarded-For`. -The immediate proxy (`RemoteAddr`) is validated for trust separately before either header is trusted. +Security event labels passed through `Metrics.RecordSecurityEvent(...)` are the stable exported `clientip.SecurityEvent...` constants. -`WithAllowedReservedClientPrefixes` only bypasses reserved/special-use filtering for matching ranges. -It does not bypass loopback/link-local/multicast/unspecified rejection, and private-IP policy remains controlled by `WithAllowPrivateIPs`. +## Prometheus Metrics -## Extraction +Construct Prometheus metrics explicitly and pass them through `Config.Metrics`. ```go -type Source struct { /* opaque */ } +import clientipprom "github.com/abczzz13/clientip/prometheus" -type Extraction struct { - IP netip.Addr - Source Source - TrustedProxyCount int - DebugInfo *ChainDebugInfo +metrics, err := clientipprom.New() +if err != nil { + panic(err) } -type HeaderValues interface { - Values(name string) []string +extractor, err := clientip.New(clientip.Config{Metrics: metrics}) +if err != nil { + panic(err) } -type RequestInput struct { - Context context.Context - RemoteAddr string - Path string - Headers HeaderValues +resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{}) +if err != nil { + panic(err) } - -func (e *Extractor) Extract(req *http.Request, callOpts ...CallOption) (Extraction, error) -func (e *Extractor) ExtractAddr(req *http.Request, callOpts ...CallOption) (netip.Addr, error) -func (e *Extractor) ExtractFrom(input RequestInput, callOpts ...CallOption) (Extraction, error) -func (e *Extractor) ExtractAddrFrom(input RequestInput, callOpts ...CallOption) (netip.Addr, error) ``` -When `Extract` returns a non-nil error, the returned `Extraction` value is -best-effort metadata only (typically `Source` when available). For chain -diagnostics, inspect typed errors like `ProxyValidationError` and -`InvalidIPError`. - -`Source` is an opaque value type. Use the built-ins (`SourceForwarded`, -`SourceXForwardedFor`, `SourceXRealIP`, `SourceRemoteAddr`) or -`HeaderSource(...)` to construct it, compare `Source` values directly, and use -`Source.String()` when you need the canonical identifier. -For text/JSON encoding, built-ins use canonical identifiers while custom -headers preserve canonical MIME header names for lossless round-tripping. +With a custom registerer: ```go -source := clientip.HeaderSource("CF-Connecting-IP") -text, _ := source.MarshalText() - -source.String() // "cf_connecting_ip" -string(text) // "Cf-Connecting-Ip" -``` - -That means custom-header `Source.String()` values are for canonical identifiers -inside this package, while text/JSON encoding preserves the actual header name. - -Per-call options let you temporarily adjust policy for a single extraction: - -```go -extraction, err := extractor.Extract( - req, - clientip.WithCallSecurityMode(clientip.SecurityModeLax), -) -``` - -```go -extraction, err := extractor.ExtractFrom( - input, - clientip.WithCallSourcePriority( - clientip.HeaderSource("CF-Connecting-IP"), - clientip.SourceRemoteAddr, - ), -) -``` - -Multiple `CallOption` values are applied left-to-right; later values -win. Only policy fields are overrideable (logger and metrics stay fixed per -extractor instance). - -Available call options: - -- `WithCallTrustedProxyPrefixes(...netip.Prefix)` replaces the extractor's trusted proxy prefixes for that call -- `WithCallMinTrustedProxies(int)` / `WithCallMaxTrustedProxies(int)` -- `WithCallAllowPrivateIPs(bool)` -- `WithCallAllowedReservedClientPrefixes(...netip.Prefix)` replaces the extractor's reserved-prefix allowlist for that call -- `WithCallMaxChainLength(int)` -- `WithCallChainSelection(ChainSelection)` -- `WithCallSecurityMode(SecurityMode)` -- `WithCallDebugInfo(bool)` -- `WithCallSourcePriority(...Source)` - -## Errors +registry := prometheus.NewRegistry() -```go -_, err := extractor.Extract(req) +metrics, err := clientipprom.NewWithRegisterer(registry) if err != nil { - switch { - case errors.Is(err, clientip.ErrMultipleSingleIPHeaders): - // Duplicate single-IP header values received - case errors.Is(err, clientip.ErrInvalidForwardedHeader): - // Malformed Forwarded header - case errors.Is(err, clientip.ErrUntrustedProxy): - // Forwarded/XFF came from an untrusted immediate proxy - case errors.Is(err, clientip.ErrNoTrustedProxies): - // No trusted proxies found in the chain - case errors.Is(err, clientip.ErrTooFewTrustedProxies): - // Trusted proxy count is below configured minimum - case errors.Is(err, clientip.ErrTooManyTrustedProxies): - // Trusted proxy count exceeds configured maximum - case errors.Is(err, clientip.ErrInvalidIP): - // Invalid or implausible client IP - case errors.Is(err, clientip.ErrSourceUnavailable): - // Requested source was not present on this request - case errors.Is(err, clientip.ErrNilRequest): - // Extract/ExtractAddr received a nil *http.Request - } - - var mh *clientip.MultipleHeadersError - if errors.As(err, &mh) { - // Inspect mh.HeaderName, mh.HeaderCount, or mh.RemoteAddr - } + panic(err) } ``` -Typed chain-related errors expose additional context: +## Security Guidance -- `ProxyValidationError`: `Chain`, `TrustedProxyCount`, `MinTrustedProxies`, `MaxTrustedProxies` -- `InvalidIPError`: `Chain`, `ExtractedIP`, `Index`, `TrustedProxies` - -## Security notes +- Use `ResolveStrict` or `Extractor` for security-sensitive and audit-oriented behavior. +- Use `ResolvePreferred` only when explicit fallback is acceptable for operational reasons. +- Do not use preferred fallback for authorization, ACLs, or trust-boundary enforcement. +- Do not include multiple competing header-based sources for security decisions. +- Do not trust broad proxy CIDRs unless they are truly under your control. +- Header-based sources require `TrustedProxyPrefixes`. +- `LeftmostUntrustedIP` only makes sense when trusted proxy prefixes are configured. -- Parses RFC7239 `Forwarded` header (`for=` chain) and rejects malformed values -- Parses multiple `X-Forwarded-For` header lines as one chain (wire order preserved) -- Rejects multiple values for single-IP headers (for example repeated `X-Real-IP`) -- Requires the immediate proxy (`RemoteAddr`) to be trusted before honoring `Forwarded` or `X-Forwarded-For` (when trusted proxy prefixes are configured) -- Requires trusted proxy prefixes for any header-based source -- Allows at most one chain-header source (`Forwarded` or `X-Forwarded-For`) per extractor configuration -- Enforces trusted proxy count bounds and chain length -- Filters implausible IPs (loopback, multicast, reserved); optional private-IP and reserved-prefix allowlists -- Strict fail-closed behavior is the default (`SecurityModeStrict`) for security-significant errors and invalid present source values -- Set `WithSecurityMode(SecurityModeLax)` to continue fallback after security errors - -## Security anti-patterns +## Compatibility -- Do not include multiple competing header-based sources in `WithSourcePriority(...)` for security decisions (for example custom header + chain header fallback). Prefer one canonical trusted header plus `SourceRemoteAddr` fallback only when required. -- Do not enable `SecurityModeLax` for security-enforcement decisions (ACLs, fraud/risk controls, authz). Use strict mode and fail closed. -- Do not trust broad proxy CIDRs unless they are truly under your control. Keep trusted ranges minimal and explicit. -- Do not treat a missing/invalid source as benign in critical paths; monitor and remediate extraction errors. +- Core module (`github.com/abczzz13/clientip`) supports Go `1.21+`. +- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) has a minimum Go version of `1.21`; CI currently validates consumer mode on Go `1.21.x` and `1.26.x`. +- Prometheus client dependency in the adapter is pinned to `github.com/prometheus/client_golang v1.21.1`. ## Performance -- O(n) in chain length; extractor is safe for concurrent reuse +- Extraction is `O(n)` in proxy-chain length. +- `Extractor` is safe for concurrent reuse. +- `Resolver` adds request-scoped caching on top of a reusable extractor. Benchmark workflow with `just`: @@ -535,7 +381,7 @@ just bench-compare-saved before after You can compare arbitrary files directly via `just bench-compare `. -## Maintainer notes (multi-module) +## Maintainer Notes (Multi-Module) - `prometheus/go.mod` intentionally does not use a local `replace` directive for `github.com/abczzz13/clientip`. - For local co-development, create an uncommitted workspace with `go work init . ./prometheus`. diff --git a/benchmark_test.go b/benchmark_test.go index 60e905f..bde0a75 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -3,12 +3,31 @@ package clientip import ( "context" "net/http" - "net/netip" "testing" ) +func mustBenchmarkExtractor(b *testing.B, cfg Config) *Extractor { + b.Helper() + extractor, err := New(cfg) + if err != nil { + b.Fatalf("New() error = %v", err) + } + return extractor +} + +func benchmarkExtractionLoop(b *testing.B, extract func() (Extraction, error)) { + b.Helper() + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, err := extract() + if err != nil || !result.IP.IsValid() { + b.Fatal("extraction failed") + } + } +} + func BenchmarkExtract_RemoteAddr(b *testing.B) { - extractor, _ := New() + extractor, _ := New(DefaultConfig()) req := &http.Request{ RemoteAddr: "1.1.1.1:12345", Header: make(http.Header), @@ -24,10 +43,10 @@ func BenchmarkExtract_RemoteAddr(b *testing.B) { } func BenchmarkExtract_XForwardedFor_Simple(b *testing.B) { - extractor, _ := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:12345", Header: make(http.Header), @@ -45,12 +64,12 @@ func BenchmarkExtract_XForwardedFor_Simple(b *testing.B) { func BenchmarkExtract_XForwardedFor_WithTrustedProxies(b *testing.B) { cidrs, _ := ParseCIDRs("10.0.0.0/8") - extractor, _ := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(2), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 2 + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "10.0.0.1:12345", Header: make(http.Header), @@ -67,10 +86,10 @@ func BenchmarkExtract_XForwardedFor_WithTrustedProxies(b *testing.B) { } func BenchmarkExtract_Forwarded_Simple(b *testing.B) { - extractor, _ := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceForwarded, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:12345", Header: make(http.Header), @@ -87,10 +106,10 @@ func BenchmarkExtract_Forwarded_Simple(b *testing.B) { } func BenchmarkExtract_Forwarded_WithParams(b *testing.B) { - extractor, _ := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceForwarded, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:12345", Header: make(http.Header), @@ -108,12 +127,12 @@ func BenchmarkExtract_Forwarded_WithParams(b *testing.B) { func BenchmarkExtract_XForwardedFor_LongChain(b *testing.B) { cidrs, _ := ParseCIDRs("10.0.0.0/8") - extractor, _ := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(5), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 5 + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "10.0.0.5:12345", Header: make(http.Header), @@ -131,13 +150,13 @@ func BenchmarkExtract_XForwardedFor_LongChain(b *testing.B) { func BenchmarkExtract_WithDebugInfo(b *testing.B) { cidrs, _ := ParseCIDRs("10.0.0.0/8") - extractor, _ := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(2), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithDebugInfo(true), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 2 + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + cfg.DebugInfo = true + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "10.0.0.1:12345", Header: make(http.Header), @@ -158,13 +177,13 @@ func BenchmarkExtract_WithDebugInfo(b *testing.B) { func BenchmarkExtract_LeftmostUntrustedSelection(b *testing.B) { cidrs, _ := ParseCIDRs("173.245.48.0/20") - extractor, _ := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithChainSelection(LeftmostUntrustedIP), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 3 + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + cfg.ChainSelection = LeftmostUntrustedIP + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "173.245.48.5:443", Header: make(http.Header), @@ -181,14 +200,10 @@ func BenchmarkExtract_LeftmostUntrustedSelection(b *testing.B) { } func BenchmarkExtract_CustomHeader(b *testing.B) { - extractor, _ := New( - WithTrustedLoopbackProxy(), - WithSourcePriority( - HeaderSource("CF-Connecting-IP"), - SourceXForwardedFor, - SourceRemoteAddr, - ), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{HeaderSource("CF-Connecting-IP"), SourceXForwardedFor, SourceRemoteAddr} + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:12345", Header: make(http.Header), @@ -205,10 +220,10 @@ func BenchmarkExtract_CustomHeader(b *testing.B) { } func BenchmarkExtract_Fallback_MissingPreferredHeader(b *testing.B) { - extractor, _ := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXRealIP, SourceXForwardedFor, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXRealIP, SourceXForwardedFor, SourceRemoteAddr} + extractor, _ := New(cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:12345", Header: make(http.Header), @@ -225,7 +240,7 @@ func BenchmarkExtract_Fallback_MissingPreferredHeader(b *testing.B) { } func BenchmarkExtract_Parallel(b *testing.B) { - extractor, _ := New() + extractor, _ := New(DefaultConfig()) req := &http.Request{ RemoteAddr: "1.1.1.1:12345", Header: make(http.Header), @@ -243,8 +258,8 @@ func BenchmarkExtract_Parallel(b *testing.B) { } func BenchmarkExtractFrom_HTTP_RemoteAddr(b *testing.B) { - extractor, _ := New() - input := RequestInput{ + extractor, _ := New(DefaultConfig()) + input := Input{ Context: context.Background(), RemoteAddr: "1.1.1.1:12345", Headers: make(http.Header), @@ -252,7 +267,7 @@ func BenchmarkExtractFrom_HTTP_RemoteAddr(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - result, err := extractor.ExtractFrom(input) + result, err := extractor.ExtractInput(input) if err != nil || !result.IP.IsValid() { b.Fatal("extraction failed") } @@ -260,13 +275,13 @@ func BenchmarkExtractFrom_HTTP_RemoteAddr(b *testing.B) { } func BenchmarkExtractFrom_HTTP_XForwardedFor_Simple(b *testing.B) { - extractor, _ := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor, _ := New(cfg) headers := make(http.Header) headers.Set("X-Forwarded-For", "1.1.1.1") - input := RequestInput{ + input := Input{ Context: context.Background(), RemoteAddr: "127.0.0.1:12345", Headers: headers, @@ -274,7 +289,7 @@ func BenchmarkExtractFrom_HTTP_XForwardedFor_Simple(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - result, err := extractor.ExtractFrom(input) + result, err := extractor.ExtractInput(input) if err != nil || !result.IP.IsValid() { b.Fatal("extraction failed") } @@ -282,12 +297,12 @@ func BenchmarkExtractFrom_HTTP_XForwardedFor_Simple(b *testing.B) { } func BenchmarkExtractFrom_HeaderValuesFunc_XForwardedFor_Simple(b *testing.B) { - extractor, _ := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor, _ := New(cfg) xffValues := []string{"1.1.1.1"} - input := RequestInput{ + input := Input{ Context: context.Background(), RemoteAddr: "127.0.0.1:12345", Headers: HeaderValuesFunc(func(name string) []string { @@ -300,152 +315,74 @@ func BenchmarkExtractFrom_HeaderValuesFunc_XForwardedFor_Simple(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - result, err := extractor.ExtractFrom(input) + result, err := extractor.ExtractInput(input) if err != nil || !result.IP.IsValid() { b.Fatal("extraction failed") } } } -func BenchmarkParseIP(b *testing.B) { - testCases := []string{ - "1.1.1.1", - " 1.1.1.1 ", - "1.1.1.1:8080", - "[2606:4700:4700::1]", - "[2606:4700:4700::1]:8080", - `"1.1.1.1"`, - } - - for _, tc := range testCases { - b.Run(tc, func(b *testing.B) { - for i := 0; i < b.N; i++ { - ip := parseIP(tc) - if !ip.IsValid() { - b.Fatal("parsing failed") - } - } - }) - } -} - -func BenchmarkParseCIDRs(b *testing.B) { - cidrs := []string{ - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "2606:4700:4700::/32", - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := ParseCIDRs(cidrs...) - if err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkIsTrustedProxy(b *testing.B) { - cidrs, _ := ParseCIDRs("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16") - extractor := &Extractor{ - config: &config{ - trustedProxyCIDRs: cidrs, - }, - } +func BenchmarkExtract_RequestVsInput_RemoteAddr(b *testing.B) { + extractor := mustBenchmarkExtractor(b, DefaultConfig()) + req := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: make(http.Header)} + input := Input{Context: context.Background(), RemoteAddr: "1.1.1.1:12345"} - testIPs := []netip.Addr{ - netip.MustParseAddr("10.0.0.1"), - netip.MustParseAddr("172.16.0.1"), - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("1.1.1.1"), - } + b.Run("request", func(b *testing.B) { + benchmarkExtractionLoop(b, func() (Extraction, error) { return extractor.Extract(req) }) + }) - b.ResetTimer() - for i := 0; i < b.N; i++ { - for _, ip := range testIPs { - extractor.isTrustedProxy(ip) - } - } + b.Run("input", func(b *testing.B) { + benchmarkExtractionLoop(b, func() (Extraction, error) { return extractor.ExtractInput(input) }) + }) } -func BenchmarkIsTrustedProxy_LargeCIDRSet_Precomputed(b *testing.B) { - const prefixCount = 4096 - prefixes := make([]netip.Prefix, 0, prefixCount) - for i := 0; i < prefixCount; i++ { - secondOctet := byte((i / 16) % 256) - thirdOctet := byte(i % 256) - prefixes = append(prefixes, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, secondOctet, thirdOctet, 0}), 24)) +func BenchmarkExtract_RequestVsInput_XForwardedFor(b *testing.B) { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustBenchmarkExtractor(b, cfg) + req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} + req.Header.Set("X-Forwarded-For", "1.1.1.1") + inputHTTP := Input{ + Context: context.Background(), + RemoteAddr: "127.0.0.1:12345", + Headers: http.Header{"X-Forwarded-For": {"1.1.1.1"}}, } - - extractor, _ := New( - WithTrustedProxyPrefixes(prefixes...), - WithMinTrustedProxies(0), - WithMaxTrustedProxies(0), - ) - ip := netip.MustParseAddr("10.128.8.8") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - if !extractor.isTrustedProxy(ip) { - b.Fatal("expected trusted proxy") - } + inputFunc := Input{ + Context: context.Background(), + RemoteAddr: "127.0.0.1:12345", + Headers: HeaderValuesFunc(func(name string) []string { + if name == "X-Forwarded-For" { + return []string{"1.1.1.1"} + } + return nil + }), } -} -func BenchmarkIsTrustedProxy_LargeCIDRSet_LinearFallback(b *testing.B) { - const prefixCount = 4096 - prefixes := make([]netip.Prefix, 0, prefixCount) - for i := 0; i < prefixCount; i++ { - secondOctet := byte((i / 16) % 256) - thirdOctet := byte(i % 256) - prefixes = append(prefixes, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, secondOctet, thirdOctet, 0}), 24)) - } + b.Run("request", func(b *testing.B) { + benchmarkExtractionLoop(b, func() (Extraction, error) { return extractor.Extract(req) }) + }) - extractor := &Extractor{config: &config{trustedProxyCIDRs: prefixes}} - ip := netip.MustParseAddr("10.128.8.8") + b.Run("input_http_header", func(b *testing.B) { + benchmarkExtractionLoop(b, func() (Extraction, error) { return extractor.ExtractInput(inputHTTP) }) + }) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if !extractor.isTrustedProxy(ip) { - b.Fatal("expected trusted proxy") - } - } + b.Run("input_header_func", func(b *testing.B) { + benchmarkExtractionLoop(b, func() (Extraction, error) { return extractor.ExtractInput(inputFunc) }) + }) } -func BenchmarkChainAnalysis_Rightmost(b *testing.B) { - cidrs, _ := ParseCIDRs("10.0.0.0/8") - extractor, _ := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - ) - - parts := []string{"1.1.1.1", "8.8.8.8", "10.0.0.1", "10.0.0.2"} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := extractor.analyzeChainRightmost(parts) - if err != nil { - b.Fatal(err) - } +func BenchmarkParseCIDRs(b *testing.B) { + cidrs := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "2606:4700:4700::/32", } -} - -func BenchmarkChainAnalysis_Leftmost(b *testing.B) { - cidrs, _ := ParseCIDRs("10.0.0.0/8") - extractor, _ := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithChainSelection(LeftmostUntrustedIP), - ) - - parts := []string{"1.1.1.1", "8.8.8.8", "10.0.0.1", "10.0.0.2"} b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := extractor.analyzeChainLeftmost(parts) + _, err := ParseCIDRs(cidrs...) if err != nil { b.Fatal(err) } diff --git a/call_options.go b/call_options.go deleted file mode 100644 index 32ddf63..0000000 --- a/call_options.go +++ /dev/null @@ -1,107 +0,0 @@ -package clientip - -import ( - "fmt" - "net/netip" - "slices" -) - -// WithCallTrustedProxyPrefixes replaces trusted proxy prefixes for one call. -func WithCallTrustedProxyPrefixes(prefixes ...netip.Prefix) CallOption { - prefixes = slices.Clone(prefixes) - - return func(c *config) error { - normalized, err := normalizeTrustedProxyPrefixes(prefixes) - if err != nil { - return err - } - - c.trustedProxyCIDRs = mergeUniquePrefixes(nil, normalized...) - return nil - } -} - -// WithCallMinTrustedProxies overrides the minimum trusted proxy count for one call. -func WithCallMinTrustedProxies(min int) CallOption { - return func(c *config) error { - c.minTrustedProxies = min - return nil - } -} - -// WithCallMaxTrustedProxies overrides the maximum trusted proxy count for one call. -func WithCallMaxTrustedProxies(max int) CallOption { - return func(c *config) error { - c.maxTrustedProxies = max - return nil - } -} - -// WithCallAllowPrivateIPs overrides private-IP policy for one call. -func WithCallAllowPrivateIPs(allow bool) CallOption { - return func(c *config) error { - c.allowPrivateIPs = allow - return nil - } -} - -// WithCallAllowedReservedClientPrefixes replaces reserved-prefix allowlist for one call. -func WithCallAllowedReservedClientPrefixes(prefixes ...netip.Prefix) CallOption { - prefixes = slices.Clone(prefixes) - - return func(c *config) error { - normalized, err := normalizeReservedClientPrefixes(prefixes) - if err != nil { - return err - } - - c.allowReservedClientPrefixes = mergeUniquePrefixes(nil, normalized...) - return nil - } -} - -// WithCallMaxChainLength overrides max chain length for one call. -func WithCallMaxChainLength(max int) CallOption { - return func(c *config) error { - c.maxChainLength = max - return nil - } -} - -// WithCallChainSelection overrides chain selection mode for one call. -func WithCallChainSelection(selection ChainSelection) CallOption { - return func(c *config) error { - c.chainSelection = selection - return nil - } -} - -// WithCallSecurityMode overrides security mode for one call. -func WithCallSecurityMode(mode SecurityMode) CallOption { - return func(c *config) error { - c.securityMode = mode - return nil - } -} - -// WithCallDebugInfo overrides debug-info output for one call. -func WithCallDebugInfo(enable bool) CallOption { - return func(c *config) error { - c.debugMode = enable - return nil - } -} - -// WithCallSourcePriority overrides source priority for one call. -func WithCallSourcePriority(sources ...Source) CallOption { - sources = canonicalizeSources(slices.Clone(sources)) - - return func(c *config) error { - if len(sources) == 0 { - return fmt.Errorf("at least one source required in WithCallSourcePriority") - } - - c.sourcePriority = slices.Clone(sources) - return nil - } -} diff --git a/chain_analysis.go b/chain_analysis.go deleted file mode 100644 index 4db6e28..0000000 --- a/chain_analysis.go +++ /dev/null @@ -1,233 +0,0 @@ -package clientip - -import ( - "net/netip" - "strings" -) - -type chainAnalysis struct { - clientIndex int - trustedCount int - trustedIndices []int -} - -func (e *Extractor) clientIPFromChainWithDebug(source Source, parts []string) (netip.Addr, int, *ChainDebugInfo, error) { - if len(parts) == 0 { - return netip.Addr{}, 0, nil, &ExtractionError{ - Err: ErrInvalidIP, - Source: source, - } - } - - analysis, clientIP, err := e.analyzeChainForExtraction(parts, e.config.debugMode) - - var debugInfo *ChainDebugInfo - if e.config.debugMode { - debugInfo = &ChainDebugInfo{ - FullChain: parts, - ClientIndex: analysis.clientIndex, - TrustedIndices: analysis.trustedIndices, - } - } - - if err != nil { - chain := strings.Join(parts, ", ") - return netip.Addr{}, analysis.trustedCount, debugInfo, &ProxyValidationError{ - ExtractionError: ExtractionError{ - Err: err, - Source: source, - }, - Chain: chain, - TrustedProxyCount: analysis.trustedCount, - MinTrustedProxies: e.config.minTrustedProxies, - MaxTrustedProxies: e.config.maxTrustedProxies, - } - } - - clientIPStr := parts[analysis.clientIndex] - - if !e.isPlausibleClientIP(clientIP) { - chain := strings.Join(parts, ", ") - return netip.Addr{}, analysis.trustedCount, debugInfo, &InvalidIPError{ - ExtractionError: ExtractionError{ - Err: ErrInvalidIP, - Source: source, - }, - Chain: chain, - ExtractedIP: clientIPStr, - Index: analysis.clientIndex, - TrustedProxies: analysis.trustedCount, - } - } - - return normalizeIP(clientIP), analysis.trustedCount, debugInfo, nil -} - -func (e *Extractor) analyzeChainForExtraction(parts []string, collectTrustedIndices bool) (chainAnalysis, netip.Addr, error) { - if len(parts) == 0 { - return chainAnalysis{}, netip.Addr{}, nil - } - - if e.config.chainSelection == LeftmostUntrustedIP { - return e.analyzeChainLeftmostForExtraction(parts, collectTrustedIndices) - } - return e.analyzeChainRightmostForExtraction(parts, collectTrustedIndices) -} - -func (e *Extractor) analyzeChainRightmost(parts []string) (chainAnalysis, error) { - analysis, _, err := e.analyzeChainRightmostForExtraction(parts, true) - return analysis, err -} - -func (e *Extractor) analyzeChainRightmostForExtraction(parts []string, collectTrustedIndices bool) (chainAnalysis, netip.Addr, error) { - trustedCount := 0 - clientIndex := 0 - clientIP := netip.Addr{} - - var trustedIndices []int - if collectTrustedIndices { - trustedIndices = make([]int, 0, len(parts)) - } - - hasCIDRs := len(e.config.trustedProxyCIDRs) > 0 - - for i := len(parts) - 1; i >= 0; i-- { - if !hasCIDRs && e.config.maxTrustedProxies > 0 && trustedCount >= e.config.maxTrustedProxies { - clientIndex = i - clientIP = parseIP(parts[i]) - break - } - - ip := parseIP(parts[i]) - - if hasCIDRs && !e.isTrustedProxy(ip) { - clientIndex = i - clientIP = ip - break - } - - if collectTrustedIndices { - trustedIndices = append(trustedIndices, i) - } - trustedCount++ - clientIP = ip - } - - analysis := chainAnalysis{ - clientIndex: clientIndex, - trustedCount: trustedCount, - trustedIndices: trustedIndices, - } - - if err := e.validateProxyCount(trustedCount); err != nil { - return analysis, netip.Addr{}, err - } - - return analysis, clientIP, nil -} - -func (e *Extractor) analyzeChainLeftmost(parts []string) (chainAnalysis, error) { - analysis, _, err := e.analyzeChainLeftmostForExtraction(parts, true) - return analysis, err -} - -func (e *Extractor) analyzeChainLeftmostForExtraction(parts []string, collectTrustedIndices bool) (chainAnalysis, netip.Addr, error) { - if len(e.config.trustedProxyCIDRs) == 0 { - analysis := chainAnalysis{clientIndex: 0, trustedCount: 0} - return analysis, parseIP(parts[0]), nil - } - - trustedCount := 0 - leftmostUntrustedIndex := -1 - leftmostUntrustedIP := netip.Addr{} - hasLeftmostUntrusted := false - - fallbackClientIndex := 0 - fallbackClientIP := netip.Addr{} - hasFallbackClient := false - - var trustedIndices []int - if collectTrustedIndices { - trustedIndices = make([]int, 0, len(parts)) - } - - stillTrailingTrusted := true - - for i := len(parts) - 1; i >= 0; i-- { - ip := parseIP(parts[i]) - isTrusted := e.isTrustedProxy(ip) - - if stillTrailingTrusted && isTrusted { - if collectTrustedIndices { - trustedIndices = append(trustedIndices, i) - } - trustedCount++ - continue - } - - if stillTrailingTrusted { - fallbackClientIndex = i - fallbackClientIP = ip - hasFallbackClient = true - } - - stillTrailingTrusted = false - if !isTrusted { - leftmostUntrustedIndex = i - leftmostUntrustedIP = ip - hasLeftmostUntrusted = true - } - } - - analysis := chainAnalysis{ - trustedCount: trustedCount, - } - if collectTrustedIndices { - analysis.trustedIndices = trustedIndices - } - - if err := e.validateProxyCount(trustedCount); err != nil { - return analysis, netip.Addr{}, err - } - - if hasLeftmostUntrusted { - analysis.clientIndex = leftmostUntrustedIndex - return analysis, leftmostUntrustedIP, nil - } - - if hasFallbackClient { - analysis.clientIndex = fallbackClientIndex - return analysis, fallbackClientIP, nil - } - - analysis.clientIndex = 0 - return analysis, parseIP(parts[analysis.clientIndex]), nil -} - -func (e *Extractor) selectLeftmostUntrustedIP(parts []string, trustedProxiesFromRight int) int { - trustedFlags := make([]bool, len(parts)) - for i, part := range parts { - trustedFlags[i] = e.isTrustedProxy(parseIP(part)) - } - - return selectLeftmostUntrustedTrusted(trustedFlags, trustedProxiesFromRight) -} - -func selectLeftmostUntrustedTrusted(trustedFlags []bool, trustedProxiesFromRight int) int { - untrustedPortionEnd := len(trustedFlags) - trustedProxiesFromRight - if untrustedPortionEnd < 0 { - untrustedPortionEnd = 0 - } - - for i := 0; i < untrustedPortionEnd; i++ { - if !trustedFlags[i] { - return i - } - } - - if untrustedPortionEnd <= 0 { - return 0 - } - - return untrustedPortionEnd - 1 -} diff --git a/chain_validation.go b/chain_validation.go deleted file mode 100644 index 37bec6e..0000000 --- a/chain_validation.go +++ /dev/null @@ -1,57 +0,0 @@ -package clientip - -import "net/netip" - -func (e *Extractor) isTrustedProxy(ip netip.Addr) bool { - if !ip.IsValid() { - return false - } - - if e.config.trustedProxyMatch.initialized { - return e.config.trustedProxyMatch.contains(ip) - } - - for _, cidr := range e.config.trustedProxyCIDRs { - if cidr.Contains(ip) { - return true - } - } - - return false -} - -func (e *Extractor) validateProxyCount(trustedCount int) error { - if len(e.config.trustedProxyCIDRs) > 0 && e.config.minTrustedProxies > 0 && trustedCount == 0 { - e.config.metrics.RecordSecurityEvent(securityEventNoTrustedProxies) - return ErrNoTrustedProxies - } - - if e.config.minTrustedProxies > 0 && trustedCount < e.config.minTrustedProxies { - e.config.metrics.RecordSecurityEvent(securityEventTooFewTrustedProxies) - return ErrTooFewTrustedProxies - } - - if e.config.maxTrustedProxies > 0 && trustedCount > e.config.maxTrustedProxies { - e.config.metrics.RecordSecurityEvent(securityEventTooManyTrustedProxies) - return ErrTooManyTrustedProxies - } - - return nil -} - -// appendChainPart appends one parsed chain part while enforcing maxChainLength. -func (e *Extractor) appendChainPart(parts []string, part string, source Source) ([]string, error) { - if len(parts) >= e.config.maxChainLength { - e.config.metrics.RecordSecurityEvent(securityEventChainTooLong) - return nil, &ChainTooLongError{ - ExtractionError: ExtractionError{ - Err: ErrChainTooLong, - Source: source, - }, - ChainLength: len(parts) + 1, - MaxLength: e.config.maxChainLength, - } - } - - return append(parts, part), nil -} diff --git a/classify.go b/classify.go new file mode 100644 index 0000000..80f6dfc --- /dev/null +++ b/classify.go @@ -0,0 +1,60 @@ +package clientip + +import ( + "context" + "errors" +) + +// ResultKind is a coarse-grained classification for extraction and resolution +// results. +// +// ClassifyError returns ResultSuccess for nil and ResultUnknown for non-nil +// errors outside the package's standard extraction and resolution surface. +type ResultKind uint8 + +const ( + // ResultUnknown indicates a non-nil error outside the package's standard + // extraction and resolution categories. + ResultUnknown ResultKind = iota + // ResultSuccess indicates the operation completed without error. + ResultSuccess + // ResultUnavailable indicates the selected source was not present. + ResultUnavailable + // ResultInvalid indicates invalid request input or an invalid client IP. + ResultInvalid + // ResultUntrusted indicates the request failed trusted-proxy validation. + ResultUntrusted + // ResultMalformed indicates malformed or conflicting proxy-header input. + ResultMalformed + // ResultCanceled indicates context cancellation or deadline expiry. + ResultCanceled +) + +// ClassifyError maps the package's detailed error surface into a smaller set of +// policy-oriented result kinds. +// +// This helper is additive: typed errors and errors.Is / errors.As remain the +// detailed interface when callers need source-specific diagnostics. +func ClassifyError(err error) ResultKind { + switch { + case err == nil: + return ResultSuccess + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return ResultCanceled + case errors.Is(err, ErrSourceUnavailable): + return ResultUnavailable + case errors.Is(err, ErrUntrustedProxy), + errors.Is(err, ErrNoTrustedProxies), + errors.Is(err, ErrTooFewTrustedProxies), + errors.Is(err, ErrTooManyTrustedProxies): + return ResultUntrusted + case errors.Is(err, ErrInvalidForwardedHeader), + errors.Is(err, ErrChainTooLong), + errors.Is(err, ErrMultipleSingleIPHeaders): + return ResultMalformed + case errors.Is(err, ErrInvalidIP), errors.Is(err, ErrNilRequest): + return ResultInvalid + default: + return ResultUnknown + } +} diff --git a/classify_test.go b/classify_test.go new file mode 100644 index 0000000..b886363 --- /dev/null +++ b/classify_test.go @@ -0,0 +1,37 @@ +package clientip + +import ( + "context" + "errors" + "fmt" + "testing" +) + +func TestClassifyError(t *testing.T) { + tests := []struct { + name string + err error + want ResultKind + }{ + {name: "nil", want: ResultSuccess}, + {name: "source unavailable", err: &ExtractionError{Err: ErrSourceUnavailable, Source: SourceRemoteAddr}, want: ResultUnavailable}, + {name: "invalid ip", err: &RemoteAddrError{ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: SourceRemoteAddr}, RemoteAddr: "bad"}, want: ResultInvalid}, + {name: "nil request", err: ErrNilRequest, want: ResultInvalid}, + {name: "untrusted proxy", err: &ProxyValidationError{ExtractionError: ExtractionError{Err: ErrUntrustedProxy, Source: SourceXRealIP}}, want: ResultUntrusted}, + {name: "too few trusted proxies", err: &ProxyValidationError{ExtractionError: ExtractionError{Err: ErrTooFewTrustedProxies, Source: SourceXForwardedFor}}, want: ResultUntrusted}, + {name: "malformed forwarded", err: fmt.Errorf("wrapped: %w", &ExtractionError{Err: ErrInvalidForwardedHeader, Source: SourceForwarded}), want: ResultMalformed}, + {name: "chain too long", err: &ChainTooLongError{ExtractionError: ExtractionError{Err: ErrChainTooLong, Source: SourceXForwardedFor}, ChainLength: 101, MaxLength: 100}, want: ResultMalformed}, + {name: "multiple single-ip headers", err: &MultipleHeadersError{ExtractionError: ExtractionError{Err: ErrMultipleSingleIPHeaders, Source: SourceXRealIP}, HeaderCount: 2}, want: ResultMalformed}, + {name: "canceled", err: context.Canceled, want: ResultCanceled}, + {name: "deadline exceeded", err: context.DeadlineExceeded, want: ResultCanceled}, + {name: "unknown", err: errors.New("boom"), want: ResultUnknown}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ClassifyError(tt.err); got != tt.want { + t.Fatalf("ClassifyError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} diff --git a/config.go b/config.go index 44cd266..c366130 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package clientip import ( "fmt" "net/netip" + "reflect" "slices" ) @@ -47,51 +48,67 @@ func (s ChainSelection) valid() bool { return s == RightmostUntrustedIP || s == LeftmostUntrustedIP } -// SecurityMode controls fallback behavior after security-significant errors. -type SecurityMode int - -const ( - // SecurityModeStrict fails closed and stops on security-significant errors. - SecurityModeStrict SecurityMode = iota + 1 - // SecurityModeLax allows fallback to lower-priority sources after such errors. - SecurityModeLax -) +// Config configures an Extractor. +type Config struct { + TrustedProxyPrefixes []netip.Prefix + MinTrustedProxies int + MaxTrustedProxies int + AllowPrivateIPs bool + AllowedReservedClientPrefixes []netip.Prefix + MaxChainLength int + ChainSelection ChainSelection + DebugInfo bool + Sources []Source + Logger Logger + Metrics Metrics +} -// String returns the canonical text representation of m. -func (m SecurityMode) String() string { - switch m { - case SecurityModeStrict: - return "strict" - case SecurityModeLax: - return "lax" - default: - return "unknown" +// DefaultConfig returns the default extractor configuration. +func DefaultConfig() Config { + return Config{ + MaxChainLength: DefaultMaxChainLength, + ChainSelection: RightmostUntrustedIP, + Sources: []Source{builtinSource(sourceRemoteAddr)}, } } -// valid reports whether m is a supported security mode. -func (m SecurityMode) valid() bool { - return m == SecurityModeStrict || m == SecurityModeLax +// LoopbackProxyPrefixes returns loopback CIDRs commonly used when the app sits +// behind a reverse proxy on the same host. +func LoopbackProxyPrefixes() []netip.Prefix { + return clonePrefixes(loopbackProxyCIDRs) +} + +// PrivateProxyPrefixes returns private-network CIDRs commonly used for trusted +// upstream proxies in VM and internal network deployments. +func PrivateProxyPrefixes() []netip.Prefix { + return clonePrefixes(privateProxyCIDRs) +} + +// LocalProxyPrefixes returns loopback and private-network proxy CIDRs. +func LocalProxyPrefixes() []netip.Prefix { + return mergeUniquePrefixes(clonePrefixes(loopbackProxyCIDRs), privateProxyCIDRs...) +} + +// ProxyPrefixesFromAddrs converts individual proxy addresses into host-sized +// trusted prefixes. +func ProxyPrefixesFromAddrs(addrs ...netip.Addr) ([]netip.Prefix, error) { + prefixes := make([]netip.Prefix, 0, len(addrs)) + for _, addr := range addrs { + if !addr.IsValid() { + return nil, fmt.Errorf("invalid proxy address %q", addr) + } + + addr = normalizeIP(addr) + prefixes = append(prefixes, netip.PrefixFrom(addr, addr.BitLen())) + } + + return prefixes, nil } -// Option configures an Extractor. -// -// Construct options using package-provided option builder functions. -type Option func(*config) error - -// CallOption configures one Extract/ExtractFrom call. -// -// Call options can override policy fields for a single extraction, while -// logger and metrics remain fixed at extractor construction time. -type CallOption func(*config) error - -// config holds extractor configuration state. -// -// It is mutated by Option functions during construction and by CallOption -// functions during per-call policy adjustments. +// config holds normalized runtime configuration state. type config struct { trustedProxyCIDRs []netip.Prefix - trustedProxyMatch trustedProxyMatcher + trustedProxyMatch prefixMatcher minTrustedProxies int maxTrustedProxies int @@ -99,7 +116,6 @@ type config struct { allowReservedClientPrefixes []netip.Prefix maxChainLength int chainSelection ChainSelection - securityMode SecurityMode debugMode bool sourcePriority []Source @@ -107,9 +123,90 @@ type config struct { logger Logger metrics Metrics +} + +func (c *config) validate() error { + if c.minTrustedProxies < 0 { + return fmt.Errorf("minTrustedProxies must be >= 0, got %d", c.minTrustedProxies) + } + if c.maxTrustedProxies < 0 { + return fmt.Errorf("maxTrustedProxies must be >= 0, got %d", c.maxTrustedProxies) + } + if c.maxTrustedProxies > 0 && c.minTrustedProxies > c.maxTrustedProxies { + return fmt.Errorf("minTrustedProxies (%d) cannot exceed maxTrustedProxies (%d)", c.minTrustedProxies, c.maxTrustedProxies) + } + if c.minTrustedProxies > 0 && len(c.trustedProxyCIDRs) == 0 { + return fmt.Errorf("minTrustedProxies > 0 requires TrustedProxyPrefixes to be configured for security validation; to skip validation and trust all proxies, set TrustedProxyPrefixes to 0.0.0.0/0 and ::/0") + } + if c.maxChainLength <= 0 { + return fmt.Errorf("maxChainLength must be > 0, got %d", c.maxChainLength) + } + if !c.chainSelection.valid() { + return fmt.Errorf("invalid chain selection %d (must be RightmostUntrustedIP=1 or LeftmostUntrustedIP=2)", c.chainSelection) + } + if len(c.sourcePriority) == 0 { + return fmt.Errorf("at least one source required in priority list") + } + + hasHeaderSource, hasChainSource, err := c.validateSourcePriority() + if err != nil { + return err + } - metricsFactory func() (Metrics, error) - useMetricsFactory bool + if hasChainSource && c.chainSelection == LeftmostUntrustedIP && len(c.trustedProxyCIDRs) == 0 { + return fmt.Errorf("LeftmostUntrustedIP selection requires trusted proxy prefixes to be configured; without trusted-proxy validation, this selection provides no security benefit over RightmostUntrustedIP") + } + + if hasHeaderSource && len(c.trustedProxyCIDRs) == 0 { + return fmt.Errorf("header-based sources require trusted proxy prefixes; configure TrustedProxyPrefixes directly or use LoopbackProxyPrefixes, PrivateProxyPrefixes, LocalProxyPrefixes, or ProxyPrefixesFromAddrs") + } + + if isNilValue(c.logger) { + return fmt.Errorf("logger cannot be nil") + } + if isNilValue(c.metrics) { + return fmt.Errorf("metrics cannot be nil") + } + return nil +} + +func (c *config) validateSourcePriority() (hasHeaderSource, hasChainSource bool, err error) { + seen := make(map[Source]struct{}, len(c.sourcePriority)) + seenForwarded := false + seenXFF := false + + for _, source := range c.sourcePriority { + source = canonicalSource(source) + if !source.valid() { + return false, false, fmt.Errorf("source names cannot be empty") + } + + if _, ok := seen[source]; ok { + return false, false, fmt.Errorf("duplicate source %q in priority list", source) + } + seen[source] = struct{}{} + + switch source.kind { + case sourceStaticFallback: + return false, false, fmt.Errorf("source %q is resolver-only and cannot be used in Config.Sources", source) + case sourceForwarded: + seenForwarded = true + hasChainSource = true + hasHeaderSource = true + case sourceXForwardedFor: + seenXFF = true + hasChainSource = true + hasHeaderSource = true + case sourceXRealIP, sourceHeader: + hasHeaderSource = true + } + } + + if seenForwarded && seenXFF { + return false, false, fmt.Errorf("priority cannot include both %q and %q; choose one proxy chain header", builtinSource(sourceForwarded), builtinSource(sourceXForwardedFor)) + } + + return hasHeaderSource, hasChainSource, nil } var ( @@ -142,16 +239,22 @@ func clonePrefixes(prefixes []netip.Prefix) []netip.Prefix { return slices.Clone(prefixes) } -func cloneAddrs(addrs []netip.Addr) []netip.Addr { - return slices.Clone(addrs) -} - -func cloneStrings(values []string) []string { +func cloneSources(values []Source) []Source { return slices.Clone(values) } -func cloneSources(values []Source) []Source { - return slices.Clone(values) +func isNilValue(v any) bool { + if v == nil { + return true + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return rv.IsNil() + default: + return false + } } func normalizePrefixes(prefixes []netip.Prefix, kind string) ([]netip.Prefix, error) { @@ -201,173 +304,67 @@ func mergeUniquePrefixes(existing []netip.Prefix, additions ...netip.Prefix) []n return merged } -func appendTrustedProxyCIDRs(c *config, prefixes ...netip.Prefix) { - if len(prefixes) == 0 { - return - } - - c.trustedProxyCIDRs = mergeUniquePrefixes(c.trustedProxyCIDRs, prefixes...) -} - func defaultConfig() *config { + defaults := DefaultConfig() return &config{ - minTrustedProxies: 0, - maxTrustedProxies: 0, - allowPrivateIPs: false, - maxChainLength: DefaultMaxChainLength, - chainSelection: RightmostUntrustedIP, - securityMode: SecurityModeStrict, + minTrustedProxies: defaults.MinTrustedProxies, + maxTrustedProxies: defaults.MaxTrustedProxies, + allowPrivateIPs: defaults.AllowPrivateIPs, + maxChainLength: defaults.MaxChainLength, + chainSelection: defaults.ChainSelection, logger: noopLogger{}, metrics: noopMetrics{}, - sourcePriority: []Source{ - builtinSource(sourceRemoteAddr), - }, - } -} - -func applyOptions(c *config, opts ...Option) error { - for _, opt := range opts { - if err := opt(c); err != nil { - return err - } + sourcePriority: cloneSources(defaults.Sources), } - - return nil } -func configFromOptions(opts ...Option) (*config, error) { +func configFromPublic(public Config) (*config, error) { cfg := defaultConfig() - if err := applyOptions(cfg, opts...); err != nil { - return nil, err - } - - cfg.sourceHeaderKeys = sourceHeaderKeys(cfg.sourcePriority) - - appendTrustedProxyCIDRs(cfg, cfg.trustedProxyCIDRs...) - cfg.trustedProxyMatch = buildTrustedProxyMatcher(cfg.trustedProxyCIDRs) - - if cfg.useMetricsFactory { - if cfg.metricsFactory == nil { - return nil, fmt.Errorf("metrics factory cannot be nil") - } - } - - validationConfig := cfg - if cfg.useMetricsFactory { - validationConfig = cfg.clone() - validationConfig.metrics = noopMetrics{} - } - - if err := validationConfig.validate(); err != nil { - return nil, err - } - - if cfg.useMetricsFactory { - metrics, err := cfg.metricsFactory() + if public.TrustedProxyPrefixes != nil { + normalized, err := normalizeTrustedProxyPrefixes(public.TrustedProxyPrefixes) if err != nil { return nil, err } - if isNilValue(metrics) { - return nil, fmt.Errorf("metrics cannot be nil") - } - cfg.metrics = metrics + cfg.trustedProxyCIDRs = mergeUniquePrefixes(nil, normalized...) + } - if err := cfg.validate(); err != nil { + if public.AllowedReservedClientPrefixes != nil { + normalized, err := normalizeReservedClientPrefixes(public.AllowedReservedClientPrefixes) + if err != nil { return nil, err } + cfg.allowReservedClientPrefixes = mergeUniquePrefixes(nil, normalized...) } - return cfg, nil -} - -func (c *config) clone() *config { - return &config{ - trustedProxyCIDRs: clonePrefixes(c.trustedProxyCIDRs), - trustedProxyMatch: c.trustedProxyMatch, - minTrustedProxies: c.minTrustedProxies, - maxTrustedProxies: c.maxTrustedProxies, - allowPrivateIPs: c.allowPrivateIPs, - allowReservedClientPrefixes: clonePrefixes(c.allowReservedClientPrefixes), - maxChainLength: c.maxChainLength, - chainSelection: c.chainSelection, - securityMode: c.securityMode, - debugMode: c.debugMode, - sourcePriority: cloneSources(c.sourcePriority), - sourceHeaderKeys: cloneStrings(c.sourceHeaderKeys), - logger: c.logger, - metrics: c.metrics, - metricsFactory: c.metricsFactory, - useMetricsFactory: c.useMetricsFactory, + if public.MaxChainLength != 0 { + cfg.maxChainLength = public.MaxChainLength } -} - -func (c *config) samePolicy(other *config) bool { - if other == nil { - return false + if public.ChainSelection != 0 { + cfg.chainSelection = public.ChainSelection } - - return slices.Equal(c.trustedProxyCIDRs, other.trustedProxyCIDRs) && - c.minTrustedProxies == other.minTrustedProxies && - c.maxTrustedProxies == other.maxTrustedProxies && - c.allowPrivateIPs == other.allowPrivateIPs && - slices.Equal(c.allowReservedClientPrefixes, other.allowReservedClientPrefixes) && - c.maxChainLength == other.maxChainLength && - c.chainSelection == other.chainSelection && - c.securityMode == other.securityMode && - c.debugMode == other.debugMode && - slices.Equal(c.sourcePriority, other.sourcePriority) && - slices.Equal(c.sourceHeaderKeys, other.sourceHeaderKeys) -} - -func applyCallOptions(c *config, callOpts ...CallOption) error { - for _, callOpt := range callOpts { - if callOpt == nil { - return fmt.Errorf("call option cannot be nil") - } - - if err := callOpt(c); err != nil { - return err - } + if public.Sources != nil { + cfg.sourcePriority = canonicalizeSources(cloneSources(public.Sources)) } - return nil -} + cfg.minTrustedProxies = public.MinTrustedProxies + cfg.maxTrustedProxies = public.MaxTrustedProxies + cfg.allowPrivateIPs = public.AllowPrivateIPs + cfg.debugMode = public.DebugInfo -func (c *config) withCallOptions(callOpts ...CallOption) (*config, error) { - if len(callOpts) == 0 { - return c, nil + if public.Logger != nil { + cfg.logger = public.Logger } - - effective := c.clone() - - if err := applyCallOptions(effective, callOpts...); err != nil { - return nil, err + if public.Metrics != nil { + cfg.metrics = public.Metrics } - sourcePriorityChanged := !slices.Equal(effective.sourcePriority, c.sourcePriority) - if sourcePriorityChanged { - effective.sourcePriority = canonicalizeSources(effective.sourcePriority) - effective.sourceHeaderKeys = sourceHeaderKeys(effective.sourcePriority) - } - - trustedProxyCIDRsChanged := !slices.Equal(effective.trustedProxyCIDRs, c.trustedProxyCIDRs) - if trustedProxyCIDRsChanged { - appendTrustedProxyCIDRs(effective, effective.trustedProxyCIDRs...) - trustedProxyCIDRsChanged = !slices.Equal(effective.trustedProxyCIDRs, c.trustedProxyCIDRs) - } - - if trustedProxyCIDRsChanged { - effective.trustedProxyMatch = buildTrustedProxyMatcher(effective.trustedProxyCIDRs) - } + cfg.sourceHeaderKeys = sourceHeaderKeys(cfg.sourcePriority) + cfg.trustedProxyMatch = newPrefixMatcher(cfg.trustedProxyCIDRs) - if err := effective.validate(); err != nil { + if err := cfg.validate(); err != nil { return nil, err } - if c.samePolicy(effective) { - return c, nil - } - - return effective, nil + return cfg, nil } diff --git a/config_test.go b/config_test.go index a4b204e..afd892a 100644 --- a/config_test.go +++ b/config_test.go @@ -26,7 +26,6 @@ type configSnapshot struct { AllowReservedPrefixes []string MaxChainLength int ChainSelection ChainSelection - SecurityMode SecurityMode DebugMode bool SourcePriority []string } @@ -45,7 +44,6 @@ func snapshotConfig(cfg *config) configSnapshot { AllowReservedPrefixes: cidrStrings(cfg.allowReservedClientPrefixes), MaxChainLength: cfg.maxChainLength, ChainSelection: cfg.chainSelection, - SecurityMode: cfg.securityMode, DebugMode: cfg.debugMode, SourcePriority: sourceNames(cfg.sourcePriority), } @@ -70,12 +68,15 @@ func sourceNames(sources []Source) []string { func TestNew_ConfigScenarios(t *testing.T) { tests := []struct { name string - opts []Option + buildConfig func() Config want configSnapshot wantErrText string }{ { name: "default", + buildConfig: func() Config { + return DefaultConfig() + }, want: configSnapshot{ TrustedProxyCIDRs: []string{}, MinTrustedProxies: 0, @@ -84,27 +85,27 @@ func TestNew_ConfigScenarios(t *testing.T) { AllowReservedPrefixes: []string{}, MaxChainLength: DefaultMaxChainLength, ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, DebugMode: false, SourcePriority: []string{SourceRemoteAddr.String()}, }, }, { - name: "configured options", - opts: []Option{ - WithTrustedProxyPrefixes( + name: "configured config", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("172.16.0.0/12"), - ), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithAllowPrivateIPs(true), - WithAllowedReservedClientPrefixes(netip.MustParsePrefix("198.51.100.0/24")), - WithMaxChainLength(42), - WithChainSelection(LeftmostUntrustedIP), - WithSecurityMode(SecurityModeLax), - WithDebugInfo(true), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), + } + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 3 + cfg.AllowPrivateIPs = true + cfg.AllowedReservedClientPrefixes = []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + cfg.MaxChainLength = 42 + cfg.ChainSelection = LeftmostUntrustedIP + cfg.DebugInfo = true + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + return cfg }, want: configSnapshot{ TrustedProxyCIDRs: []string{"10.0.0.0/8", "172.16.0.0/12"}, @@ -114,17 +115,17 @@ func TestNew_ConfigScenarios(t *testing.T) { AllowReservedPrefixes: []string{"198.51.100.0/24"}, MaxChainLength: 42, ChainSelection: LeftmostUntrustedIP, - SecurityMode: SecurityModeLax, DebugMode: true, SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, }, }, { name: "merge option fragments", - opts: []Option{ - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeLax), + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + return cfg }, want: configSnapshot{ TrustedProxyCIDRs: []string{"127.0.0.0/8", "::1/128"}, @@ -134,16 +135,20 @@ func TestNew_ConfigScenarios(t *testing.T) { AllowReservedPrefixes: []string{}, MaxChainLength: DefaultMaxChainLength, ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeLax, DebugMode: false, SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, }, }, { name: "merge trusted proxy prefixes", - opts: []Option{ - WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8")), - WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("172.16.0.0/12")), + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/12"), + } + return cfg }, want: configSnapshot{ TrustedProxyCIDRs: []string{"10.0.0.0/8", "172.16.0.0/12"}, @@ -153,16 +158,20 @@ func TestNew_ConfigScenarios(t *testing.T) { AllowReservedPrefixes: []string{}, MaxChainLength: DefaultMaxChainLength, ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, DebugMode: false, SourcePriority: []string{SourceRemoteAddr.String()}, }, }, { name: "merge reserved client prefixes", - opts: []Option{ - WithAllowedReservedClientPrefixes(netip.MustParsePrefix("198.51.100.0/24")), - WithAllowedReservedClientPrefixes(netip.MustParsePrefix("198.51.100.0/24"), netip.MustParsePrefix("203.0.113.0/24")), + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.AllowedReservedClientPrefixes = []netip.Prefix{ + netip.MustParsePrefix("198.51.100.0/24"), + netip.MustParsePrefix("198.51.100.0/24"), + netip.MustParsePrefix("203.0.113.0/24"), + } + return cfg }, want: configSnapshot{ TrustedProxyCIDRs: []string{}, @@ -172,26 +181,33 @@ func TestNew_ConfigScenarios(t *testing.T) { AllowReservedPrefixes: []string{"198.51.100.0/24", "203.0.113.0/24"}, MaxChainLength: DefaultMaxChainLength, ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, DebugMode: false, SourcePriority: []string{SourceRemoteAddr.String()}, }, }, { - name: "invalid trusted prefix helper", - opts: []Option{WithTrustedProxyPrefixes(netip.Prefix{})}, + name: "invalid trusted prefix helper", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = []netip.Prefix{{}} + return cfg + }, wantErrText: "invalid trusted proxy prefix", }, { - name: "invalid allow reserved prefix helper", - opts: []Option{WithAllowedReservedClientPrefixes(netip.Prefix{})}, + name: "invalid allow reserved prefix helper", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.AllowedReservedClientPrefixes = []netip.Prefix{{}} + return cfg + }, wantErrText: "invalid reserved client prefix", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - extractor, err := New(tt.opts...) + extractor, err := New(tt.buildConfig()) if tt.wantErrText != "" { if err == nil { t.Fatalf("New() error = nil, want containing %q", tt.wantErrText) @@ -213,81 +229,96 @@ func TestNew_ConfigScenarios(t *testing.T) { } } -func TestNew_InvalidOptions(t *testing.T) { +func TestNew_InvalidConfig(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) tests := []struct { name string - opts []Option + buildConfig func() Config wantErrText string }{ { - name: "min exceeds max", - opts: []Option{WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8")), WithMinTrustedProxies(5), WithMaxTrustedProxies(2)}, + name: "min exceeds max", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + cfg.MinTrustedProxies = 5 + cfg.MaxTrustedProxies = 2 + return cfg + }, wantErrText: "minTrustedProxies", }, { - name: "header source without trusted proxies", - opts: []Option{WithSourcePriority(SourceXForwardedFor)}, + name: "header source without trusted proxies", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.Sources = []Source{SourceXForwardedFor} + return cfg + }, wantErrText: "header-based sources require trusted proxy prefixes", }, { - name: "invalid chain selection", - opts: []Option{WithChainSelection(ChainSelection(999))}, + name: "invalid chain selection", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.ChainSelection = ChainSelection(999) + return cfg + }, wantErrText: "invalid chain selection", }, { - name: "invalid security mode", - opts: []Option{WithSecurityMode(SecurityMode(999))}, - wantErrText: "invalid security mode", - }, - { - name: "empty explicit source priority", - opts: []Option{WithSourcePriority()}, + name: "empty explicit source priority", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.Sources = []Source{} + return cfg + }, wantErrText: "at least one source required", }, { - name: "typed nil logger", - opts: []Option{WithLogger((*slog.Logger)(nil))}, + name: "typed nil logger", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.Logger = (*slog.Logger)(nil) + return cfg + }, wantErrText: "logger cannot be nil", }, { - name: "typed nil metrics", - opts: []Option{WithMetrics((*testTypedNilMetrics)(nil))}, + name: "typed nil metrics", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.Metrics = (*testTypedNilMetrics)(nil) + return cfg + }, wantErrText: "metrics cannot be nil", }, { - name: "invalid trust proxy addr helper", - opts: []Option{WithTrustedProxyAddrs(netip.Addr{})}, - wantErrText: "invalid proxy address", - }, - { - name: "leftmost without trusted proxies", - opts: []Option{WithSourcePriority(SourceXForwardedFor), WithChainSelection(LeftmostUntrustedIP)}, + name: "leftmost without trusted proxies", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.Sources = []Source{SourceXForwardedFor} + cfg.ChainSelection = LeftmostUntrustedIP + return cfg + }, wantErrText: "LeftmostUntrustedIP selection requires trusted proxy prefixes", }, { - name: "multiple chain sources", - opts: []Option{WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8")), WithSourcePriority(SourceForwarded, SourceXForwardedFor), WithLogger(logger)}, + name: "multiple chain sources", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + cfg.Sources = []Source{SourceForwarded, SourceXForwardedFor} + cfg.Logger = logger + return cfg + }, wantErrText: "priority cannot include both", }, - { - name: "nil metrics factory", - opts: []Option{WithMetricsFactory(nil)}, - wantErrText: "metrics factory cannot be nil", - }, - { - name: "typed nil metrics from factory", - opts: []Option{WithMetricsFactory(func() (Metrics, error) { - return (*testTypedNilMetrics)(nil), nil - })}, - wantErrText: "metrics cannot be nil", - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := New(tt.opts...) + _, err := New(tt.buildConfig()) if err == nil { t.Fatalf("New() error = nil, want containing %q", tt.wantErrText) } @@ -298,439 +329,83 @@ func TestNew_InvalidOptions(t *testing.T) { } } -func TestNew_WithMetricsFactory_Lifecycle(t *testing.T) { - t.Run("factory not called when configuration invalid", func(t *testing.T) { - calls := 0 - - _, err := New( - WithMetricsFactory(func() (Metrics, error) { - calls++ - return noopMetrics{}, nil - }), - WithSourcePriority(SourceXForwardedFor), - ) - if err == nil { - t.Fatal("New() error = nil, want non-nil") - } - if calls != 0 { - t.Fatalf("metrics factory calls = %d, want 0", calls) - } - }) - - t.Run("factory called once when configuration valid", func(t *testing.T) { - calls := 0 - - _, err := New( - WithMetricsFactory(func() (Metrics, error) { - calls++ - return noopMetrics{}, nil - }), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - if calls != 1 { - t.Fatalf("metrics factory calls = %d, want 1", calls) - } - }) - - t.Run("WithMetrics after factory disables factory", func(t *testing.T) { - calls := 0 - - _, err := New( - WithMetricsFactory(func() (Metrics, error) { - calls++ - return noopMetrics{}, nil - }), - WithMetrics(noopMetrics{}), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - if calls != 0 { - t.Fatalf("metrics factory calls = %d, want 0", calls) - } - }) - - t.Run("factory last overrides prior metrics value", func(t *testing.T) { - calls := 0 - - _, err := New( - WithMetrics((*testTypedNilMetrics)(nil)), - WithMetricsFactory(func() (Metrics, error) { - calls++ - return noopMetrics{}, nil - }), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - if calls != 1 { - t.Fatalf("metrics factory calls = %d, want 1", calls) - } - }) -} - -func TestConfig_WithCallOptions(t *testing.T) { - base, err := configFromOptions( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("configFromOptions() error = %v", err) - } - +func TestNew_InvalidBoundsAndDuplicatePriority(t *testing.T) { tests := []struct { name string - callOpts []CallOption - want configSnapshot + buildConfig func() Config wantSources []Source wantErrText string }{ { - name: "last wins on scalar", - callOpts: []CallOption{ - WithCallSecurityMode(SecurityModeLax), - WithCallSecurityMode(SecurityModeStrict), - }, - want: configSnapshot{ - TrustedProxyCIDRs: []string{"127.0.0.0/8", "::1/128"}, - MinTrustedProxies: 0, - MaxTrustedProxies: 0, - AllowPrivateIPs: false, - AllowReservedPrefixes: []string{}, - MaxChainLength: DefaultMaxChainLength, - ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, - DebugMode: false, - SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, + name: "negative min trusted proxies", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.MinTrustedProxies = -1 + return cfg }, + wantErrText: "minTrustedProxies must be >= 0", }, { - name: "call option source priority", - callOpts: []CallOption{WithCallSourcePriority(SourceRemoteAddr)}, - want: configSnapshot{ - TrustedProxyCIDRs: []string{"127.0.0.0/8", "::1/128"}, - MinTrustedProxies: 0, - MaxTrustedProxies: 0, - AllowPrivateIPs: false, - AllowReservedPrefixes: []string{}, - MaxChainLength: DefaultMaxChainLength, - ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, - DebugMode: false, - SourcePriority: []string{SourceRemoteAddr.String()}, + name: "negative max trusted proxies", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.MaxTrustedProxies = -1 + return cfg }, + wantErrText: "maxTrustedProxies must be >= 0", }, { - name: "call option trusted prefixes normalize and dedupe", - callOpts: []CallOption{WithCallTrustedProxyPrefixes( - netip.MustParsePrefix("10.0.0.1/8"), - netip.MustParsePrefix("10.0.0.2/8"), - )}, - want: configSnapshot{ - TrustedProxyCIDRs: []string{"10.0.0.0/8"}, - MinTrustedProxies: 0, - MaxTrustedProxies: 0, - AllowPrivateIPs: false, - AllowReservedPrefixes: []string{}, - MaxChainLength: DefaultMaxChainLength, - ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, - DebugMode: false, - SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, + name: "negative max chain length", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.MaxChainLength = -1 + return cfg }, + wantErrText: "maxChainLength must be > 0", }, { - name: "call option reserved allowlist", - callOpts: []CallOption{WithCallAllowedReservedClientPrefixes( - netip.MustParsePrefix("198.51.100.10/24"), - netip.MustParsePrefix("198.51.100.0/24"), - )}, - want: configSnapshot{ - TrustedProxyCIDRs: []string{"127.0.0.0/8", "::1/128"}, - MinTrustedProxies: 0, - MaxTrustedProxies: 0, - AllowPrivateIPs: false, - AllowReservedPrefixes: []string{"198.51.100.0/24"}, - MaxChainLength: DefaultMaxChainLength, - ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, - DebugMode: false, - SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, + name: "duplicate source in priority", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.Sources = []Source{SourceRemoteAddr, SourceRemoteAddr} + return cfg }, - }, - { - name: "invalid empty source priority", - callOpts: []CallOption{WithCallSourcePriority()}, - wantErrText: "at least one source required", - }, - { - name: "duplicate built-in alias after canonicalization", - callOpts: []CallOption{WithCallSourcePriority(SourceXForwardedFor, HeaderSource("X-Forwarded-For"))}, wantErrText: "duplicate source", }, { - name: "distinct custom headers with different runtime keys are allowed", - callOpts: []CallOption{WithCallSourcePriority(HeaderSource("Foo-Bar"), HeaderSource("Foo_Bar"))}, - want: configSnapshot{ - TrustedProxyCIDRs: []string{"127.0.0.0/8", "::1/128"}, - MinTrustedProxies: 0, - MaxTrustedProxies: 0, - AllowPrivateIPs: false, - AllowReservedPrefixes: []string{}, - MaxChainLength: DefaultMaxChainLength, - ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, - DebugMode: false, - SourcePriority: []string{"foo_bar", "foo_bar"}, + name: "duplicate source after canonicalization", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, HeaderSource("X-Forwarded-For")} + return cfg }, - wantSources: []Source{HeaderSource("Foo-Bar"), HeaderSource("Foo_Bar")}, - }, - { - name: "invalid trusted prefix call option", - callOpts: []CallOption{WithCallTrustedProxyPrefixes(netip.Prefix{})}, - wantErrText: "invalid trusted proxy prefix", - }, - { - name: "invalid reserved prefix call option", - callOpts: []CallOption{WithCallAllowedReservedClientPrefixes(netip.Prefix{})}, - wantErrText: "invalid reserved client prefix", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - effective, err := base.withCallOptions(tt.callOpts...) - if tt.wantErrText != "" { - if err == nil { - t.Fatalf("withCallOptions() error = nil, want containing %q", tt.wantErrText) - } - if !strings.Contains(err.Error(), tt.wantErrText) { - t.Fatalf("withCallOptions() error = %q, want containing %q", err.Error(), tt.wantErrText) - } - return - } - - if err != nil { - t.Fatalf("withCallOptions() error = %v", err) - } - - if diff := cmp.Diff(tt.want, snapshotConfig(effective)); diff != "" { - t.Fatalf("call-option config mismatch (-want +got):\n%s", diff) - } - - if tt.wantSources != nil { - if diff := cmp.Diff(tt.wantSources, effective.sourcePriority); diff != "" { - t.Fatalf("source priority mismatch (-want +got):\n%s", diff) - } - } - }) - } -} - -func TestApplyCallOptions(t *testing.T) { - t.Run("applies options in order", func(t *testing.T) { - cfg := defaultConfig() - - err := applyCallOptions( - cfg, - WithCallAllowPrivateIPs(true), - WithCallSecurityMode(SecurityModeLax), - WithCallSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("applyCallOptions() error = %v", err) - } - - got := struct { - AllowPrivateIPs bool - SecurityMode SecurityMode - SourcePriority []string - }{ - AllowPrivateIPs: cfg.allowPrivateIPs, - SecurityMode: cfg.securityMode, - SourcePriority: sourceNames(cfg.sourcePriority), - } - - want := struct { - AllowPrivateIPs bool - SecurityMode SecurityMode - SourcePriority []string - }{ - AllowPrivateIPs: true, - SecurityMode: SecurityModeLax, - SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Fatalf("applyCallOptions() mismatch (-want +got):\n%s", diff) - } - }) - - t.Run("nil option returns error", func(t *testing.T) { - err := applyCallOptions(defaultConfig(), nil) - if err == nil { - t.Fatal("applyCallOptions() error = nil, want non-nil") - } - }) -} - -func TestConfig_WithCallOptions_None_ReturnsBase(t *testing.T) { - base, err := configFromOptions( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("configFromOptions() error = %v", err) - } - - effective, err := base.withCallOptions() - if err != nil { - t.Fatalf("withCallOptions() error = %v", err) - } - - if effective != base { - t.Fatal("withCallOptions() should return original config when no call options are provided") - } -} - -func TestConfig_WithCallOptions_NoEffectiveChanges_ReturnsBase(t *testing.T) { - base, err := configFromOptions( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeStrict), - ) - if err != nil { - t.Fatalf("configFromOptions() error = %v", err) - } - - effective, err := base.withCallOptions( - WithCallSecurityMode(SecurityModeStrict), - WithCallSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("withCallOptions() error = %v", err) - } - - if effective != base { - t.Fatal("withCallOptions() should return original config when call options do not change policy") - } -} - -func TestConfig_WithCallOptions_PreservesTrustedProxyMatcherWithoutPrefixOverride(t *testing.T) { - base, err := configFromOptions( - WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("2001:db8::/32")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("configFromOptions() error = %v", err) - } - - if !base.trustedProxyMatch.initialized { - t.Fatal("base matcher should be initialized") - } - - effective, err := base.withCallOptions(WithCallSecurityMode(SecurityModeLax)) - if err != nil { - t.Fatalf("withCallOptions() error = %v", err) - } - - if effective == base { - t.Fatal("withCallOptions() should return a cloned config when call options are set") - } - - if effective.trustedProxyMatch.ipv4Root != base.trustedProxyMatch.ipv4Root { - t.Fatal("withCallOptions() should preserve the existing IPv4 matcher when trusted proxy prefixes are unchanged") - } - - if effective.trustedProxyMatch.ipv6Root != base.trustedProxyMatch.ipv6Root { - t.Fatal("withCallOptions() should preserve the existing IPv6 matcher when trusted proxy prefixes are unchanged") - } - - if !effective.trustedProxyMatch.contains(netip.MustParseAddr("10.2.3.4")) { - t.Fatal("preserved matcher should still trust configured IPv4 CIDRs") - } - - if !effective.trustedProxyMatch.contains(netip.MustParseAddr("2001:db8::10")) { - t.Fatal("preserved matcher should still trust configured IPv6 CIDRs") - } -} - -func TestConfig_WithCallOptions_RebuildsTrustedProxyMatcherWithPrefixOverride(t *testing.T) { - base, err := configFromOptions( - WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("configFromOptions() error = %v", err) - } - - effective, err := base.withCallOptions(WithCallTrustedProxyPrefixes(netip.MustParsePrefix("192.168.0.0/16"))) - if err != nil { - t.Fatalf("withCallOptions() error = %v", err) - } - - if !effective.trustedProxyMatch.initialized { - t.Fatal("trusted proxy matcher should be initialized") - } - - if !effective.trustedProxyMatch.contains(netip.MustParseAddr("192.168.1.2")) { - t.Fatal("expected overridden matcher to trust new CIDR range") - } - - if effective.trustedProxyMatch.contains(netip.MustParseAddr("10.1.2.3")) { - t.Fatal("expected overridden matcher to stop trusting previous CIDR range") - } -} - -func TestNew_InvalidBoundsAndDuplicatePriority(t *testing.T) { - tests := []struct { - name string - opts []Option - wantSources []Source - wantErrText string - }{ - { - name: "negative min trusted proxies", - opts: []Option{WithMinTrustedProxies(-1)}, - wantErrText: "minTrustedProxies must be >= 0", - }, - { - name: "negative max trusted proxies", - opts: []Option{WithMaxTrustedProxies(-1)}, - wantErrText: "maxTrustedProxies must be >= 0", - }, - { - name: "zero max chain length", - opts: []Option{WithMaxChainLength(0)}, - wantErrText: "maxChainLength must be > 0", - }, - { - name: "negative max chain length", - opts: []Option{WithMaxChainLength(-1)}, - wantErrText: "maxChainLength must be > 0", - }, - { - name: "duplicate source in priority", - opts: []Option{WithSourcePriority(SourceRemoteAddr, SourceRemoteAddr)}, wantErrText: "duplicate source", }, { - name: "duplicate source after canonicalization", - opts: []Option{WithTrustedLoopbackProxy(), WithSourcePriority(SourceXForwardedFor, HeaderSource("X-Forwarded-For"))}, - wantErrText: "duplicate source", + name: "resolver-only static fallback source", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.Sources = []Source{SourceStaticFallback} + return cfg + }, + wantErrText: "resolver-only and cannot be used in Config.Sources", }, { - name: "distinct custom headers with different runtime keys", - opts: []Option{WithTrustedLoopbackProxy(), WithSourcePriority(HeaderSource("Foo-Bar"), HeaderSource("Foo_Bar"))}, + name: "distinct custom headers with different runtime keys", + buildConfig: func() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{HeaderSource("Foo-Bar"), HeaderSource("Foo_Bar")} + return cfg + }, wantSources: []Source{HeaderSource("Foo-Bar"), HeaderSource("Foo_Bar")}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - extractor, err := New(tt.opts...) + extractor, err := New(tt.buildConfig()) if tt.wantErrText == "" && tt.wantSources != nil { if err != nil { t.Fatalf("New() error = %v", err) @@ -743,12 +418,7 @@ func TestNew_InvalidBoundsAndDuplicatePriority(t *testing.T) { } got := errorTextStateOf(err, tt.wantErrText) - - want := errorTextState{ - HasErr: true, - ContainsText: true, - } - + want := errorTextState{HasErr: true, ContainsText: true} if diff := cmp.Diff(want, got); diff != "" { t.Fatalf("New() error mismatch (-want +got):\n%s", diff) } @@ -757,7 +427,10 @@ func TestNew_InvalidBoundsAndDuplicatePriority(t *testing.T) { } func TestWithTrustedPrivateProxyRanges_AddsExpectedCIDRs(t *testing.T) { - extractor, err := New(WithTrustedPrivateProxyRanges()) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = PrivateProxyPrefixes() + + extractor, err := New(cfg) if err != nil { t.Fatalf("New() error = %v", err) } @@ -770,6 +443,33 @@ func TestWithTrustedPrivateProxyRanges_AddsExpectedCIDRs(t *testing.T) { } } +func TestProxyPrefixesFromAddrs(t *testing.T) { + t.Run("valid addrs", func(t *testing.T) { + prefixes, err := ProxyPrefixesFromAddrs( + netip.MustParseAddr("1.1.1.1"), + netip.MustParseAddr("2001:db8::1"), + ) + if err != nil { + t.Fatalf("ProxyPrefixesFromAddrs() error = %v", err) + } + + want := []string{"1.1.1.1/32", "2001:db8::1/128"} + if diff := cmp.Diff(want, cidrStrings(prefixes)); diff != "" { + t.Fatalf("prefixes mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("invalid addr", func(t *testing.T) { + _, err := ProxyPrefixesFromAddrs(netip.Addr{}) + if err == nil { + t.Fatal("ProxyPrefixesFromAddrs() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "invalid proxy address") { + t.Fatalf("ProxyPrefixesFromAddrs() error = %q, want containing %q", err.Error(), "invalid proxy address") + } + }) +} + func TestStringers(t *testing.T) { tests := []struct { name string @@ -779,9 +479,6 @@ func TestStringers(t *testing.T) { {name: "rightmost selection", got: RightmostUntrustedIP.String(), want: "rightmost_untrusted"}, {name: "leftmost selection", got: LeftmostUntrustedIP.String(), want: "leftmost_untrusted"}, {name: "unknown selection", got: ChainSelection(999).String(), want: "unknown"}, - {name: "strict mode", got: SecurityModeStrict.String(), want: "strict"}, - {name: "lax mode", got: SecurityModeLax.String(), want: "lax"}, - {name: "unknown mode", got: SecurityMode(999).String(), want: "unknown"}, } for _, tt := range tests { diff --git a/config_validation.go b/config_validation.go deleted file mode 100644 index bd82ab4..0000000 --- a/config_validation.go +++ /dev/null @@ -1,90 +0,0 @@ -package clientip - -import ( - "fmt" -) - -func (c *config) validate() error { - if c.minTrustedProxies < 0 { - return fmt.Errorf("minTrustedProxies must be >= 0, got %d", c.minTrustedProxies) - } - if c.maxTrustedProxies < 0 { - return fmt.Errorf("maxTrustedProxies must be >= 0, got %d", c.maxTrustedProxies) - } - if c.maxTrustedProxies > 0 && c.minTrustedProxies > c.maxTrustedProxies { - return fmt.Errorf("minTrustedProxies (%d) cannot exceed maxTrustedProxies (%d)", c.minTrustedProxies, c.maxTrustedProxies) - } - if c.minTrustedProxies > 0 && len(c.trustedProxyCIDRs) == 0 { - return fmt.Errorf("minTrustedProxies > 0 requires trusted proxy prefixes to be configured for security validation; to skip validation and trust all proxies, use WithTrustedProxyPrefixes(netip.MustParsePrefix(\"0.0.0.0/0\"), netip.MustParsePrefix(\"::/0\"))") - } - if c.maxChainLength <= 0 { - return fmt.Errorf("maxChainLength must be > 0, got %d", c.maxChainLength) - } - if !c.chainSelection.valid() { - return fmt.Errorf("invalid chain selection %d (must be RightmostUntrustedIP=1 or LeftmostUntrustedIP=2)", c.chainSelection) - } - if !c.securityMode.valid() { - return fmt.Errorf("invalid security mode %d (must be SecurityModeStrict=1 or SecurityModeLax=2)", c.securityMode) - } - if len(c.sourcePriority) == 0 { - return fmt.Errorf("at least one source required in priority list") - } - - hasHeaderSource, hasChainSource, err := c.validateSourcePriority() - if err != nil { - return err - } - - if hasChainSource && c.chainSelection == LeftmostUntrustedIP && len(c.trustedProxyCIDRs) == 0 { - return fmt.Errorf("LeftmostUntrustedIP selection requires trusted proxy prefixes to be configured; without trusted-proxy validation, this selection provides no security benefit over RightmostUntrustedIP") - } - - if hasHeaderSource && len(c.trustedProxyCIDRs) == 0 { - return fmt.Errorf("header-based sources require trusted proxy prefixes; configure WithTrustedProxyPrefixes or trust helpers such as WithTrustedLoopbackProxy, WithTrustedPrivateProxyRanges, WithTrustedLocalProxyDefaults, or WithTrustedProxyAddrs, or use WithCallTrustedProxyPrefixes for a single extraction") - } - - if isNilValue(c.logger) { - return fmt.Errorf("logger cannot be nil") - } - if isNilValue(c.metrics) { - return fmt.Errorf("metrics cannot be nil") - } - return nil -} - -func (c *config) validateSourcePriority() (hasHeaderSource, hasChainSource bool, err error) { - seen := make(map[Source]struct{}, len(c.sourcePriority)) - seenForwarded := false - seenXFF := false - - for _, source := range c.sourcePriority { - source = canonicalSource(source) - if !source.valid() { - return false, false, fmt.Errorf("source names cannot be empty") - } - - if _, ok := seen[source]; ok { - return false, false, fmt.Errorf("duplicate source %q in priority list", source) - } - seen[source] = struct{}{} - - if source.kind != sourceRemoteAddr { - hasHeaderSource = true - } - - switch source.kind { - case sourceForwarded: - seenForwarded = true - hasChainSource = true - case sourceXForwardedFor: - seenXFF = true - hasChainSource = true - } - } - - if seenForwarded && seenXFF { - return false, false, fmt.Errorf("priority cannot include both %q and %q; choose one proxy chain header", builtinSource(sourceForwarded), builtinSource(sourceXForwardedFor)) - } - - return hasHeaderSource, hasChainSource, nil -} diff --git a/doc.go b/doc.go index e2abcf8..dbfbfe9 100644 --- a/doc.go +++ b/doc.go @@ -1,143 +1,93 @@ // Package clientip provides secure client IP extraction from HTTP requests and -// framework-agnostic request inputs with support for proxy chains, trusted -// proxy validation, and multiple header sources. -// -// # Features -// -// - Security-first design with protection against IP spoofing and header injection -// - Flexible proxy configuration with min/max trusted proxy ranges in proxy chains -// - Multiple source support: Forwarded, X-Forwarded-For, X-Real-IP, RemoteAddr, custom headers -// - Framework-friendly RequestInput API for non-net/http integrations -// - Typed source configuration via opaque Source values, built-in Source variables, and HeaderSource(...) -// - Safe defaults: RemoteAddr-only unless header sources are explicitly configured -// - Deployment presets for common topologies (direct, loopback proxy, VM proxy) -// - Per-call policy overrides via CallOption builders such as WithCallSecurityMode -// - Optional observability with context-aware logging and pluggable metrics -// - Type-safe using modern Go netip.Addr +// framework-agnostic inputs with trusted proxy validation, explicit source +// modeling, and request-scoped resolver caching. // -// # Basic Usage +// # Choose The API +// +// Resolver is the primary integration-facing API. +// +// Use Resolver when you want to: +// +// - resolve once per request or Input +// - reuse the result later from context +// - choose between strict and preferred semantics +// - keep explicit fallback behavior on a separate layer from extraction +// +// Extractor remains the low-level strict primitive. +// +// Use Extractor when you want one direct extraction call without request-scoped +// caching or preferred fallback. +// +// Input is the framework-agnostic carrier for non-net/http integrations. +// +// ParseRemoteAddr and ClassifyError are small helpers for explicit fallback and +// policy code. ClassifyError keeps typed errors intact while providing a +// smaller ResultKind layer for middleware and policy branches. // -// Simple extraction without proxy configuration: +// # Basic Usage // -// extractor, err := clientip.New() +// extractor, err := clientip.New(clientip.PresetLoopbackReverseProxy()) // if err != nil { // log.Fatal(err) // } // -// extraction, err := extractor.Extract(req) +// resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{}) // if err != nil { -// log.Printf("extract failed: %v", err) +// log.Fatal(err) +// } +// +// req, resolution := resolver.ResolveStrict(req) +// if resolution.Err != nil { +// log.Printf("resolve failed: %v", resolution.Err) // return // } // -// fmt.Printf("Client IP: %s from %s\n", extraction.IP, extraction.Source) +// fmt.Printf("Client IP: %s from %s\n", resolution.IP, resolution.Source) // -// Framework-agnostic input is available via ExtractFrom: +// if cached, ok := clientip.StrictResolutionFromContext(req.Context()); ok { +// fmt.Printf("Cached client IP: %s\n", cached.IP) +// } +// +// Framework-agnostic input is available through ExtractInput and Resolver's +// input methods. Resolver methods return the updated request or Input so cached +// resolution state can flow through the call path: // -// extraction, err := extractor.ExtractFrom(clientip.RequestInput{ +// input, resolution := resolver.ResolveInputStrict(clientip.Input{ // Context: ctx, // RemoteAddr: remoteAddr, // Path: path, // Headers: headerProvider, // }) +// _ = input // -// Call options work with both Extract and ExtractFrom: -// -// extraction, err := extractor.ExtractFrom(input, -// clientip.WithCallSourcePriority( -// clientip.HeaderSource("CF-Connecting-IP"), -// clientip.SourceRemoteAddr, -// ), -// ) -// -// # Behind Reverse Proxy +// # Config, Sources, And Security // -// Configure trusted proxy prefixes with flexible min/max proxy count: +// Config stays flat in the current public API. Presets return Config values and +// can be tweaked before construction. // -// cidrs, _ := clientip.ParseCIDRs("10.0.0.0/8", "172.16.0.0/12") -// extractor, err := clientip.New( -// clientip.WithTrustedProxyPrefixes(cidrs...), // Trust upstream proxy ranges -// clientip.WithMinTrustedProxies(0), // Count trusted proxies present in proxy headers -// clientip.WithMaxTrustedProxies(2), -// clientip.WithSourcePriority(clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), -// clientip.WithChainSelection(clientip.RightmostUntrustedIP), -// clientip.WithAllowPrivateIPs(false), -// ) +// Source values stay public and opaque. Use the built-in extractor sources for +// request-derived extraction, SourceStaticFallback for resolver static fallback +// results, and HeaderSource for custom headers. // -// # Custom Headers +// Extractor walks Config.Sources in order. Source-unavailable errors allow the +// next source to run, while malformed headers, proxy-trust failures, chain +// limits, and implausible client IPs remain terminal. // -// Support for cloud providers and custom proxy headers: +// Header-based sources require trusted upstream proxy ranges. Configure +// TrustedProxyPrefixes directly, optionally using LoopbackProxyPrefixes, +// PrivateProxyPrefixes, LocalProxyPrefixes, or ProxyPrefixesFromAddrs. // -// extractor, _ := clientip.New( -// clientip.WithTrustedLoopbackProxy(), -// clientip.WithSourcePriority( -// clientip.HeaderSource("CF-Connecting-IP"), // Cloudflare -// clientip.SourceXForwardedFor, -// clientip.SourceRemoteAddr, -// ), -// ) -// -// Header sources require trusted upstream proxy ranges. Use -// WithTrustedProxyPrefixes(with ParseCIDRs for string inputs) or helper options like -// WithTrustedLoopbackProxy, WithTrustedPrivateProxyRanges, -// WithTrustedLocalProxyDefaults, or WithTrustedProxyAddrs. -// -// Presets are available for common setups: -// -// extractor, _ := clientip.New(clientip.PresetVMReverseProxy()) +// Preferred resolver fallback is explicit and operationally useful, but it is +// not suitable for authorization or trust-boundary enforcement. // // # Observability // -// Add logging and metrics for production monitoring: -// (Prometheus adapter package: github.com/abczzz13/clientip/prometheus) -// The logger receives req.Context(), allowing trace/span IDs to flow through. -// -// import clientipprom "github.com/abczzz13/clientip/prometheus" -// -// metrics, _ := clientipprom.New() -// -// extractor, err := clientip.New( -// clientip.WithTrustedProxyPrefixes(cidrs...), -// clientip.WithMinTrustedProxies(0), -// clientip.WithMaxTrustedProxies(3), -// clientip.WithSourcePriority(clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), -// clientip.WithLogger(slog.Default()), -// clientip.WithMetrics(metrics), -// ) -// -// # Security Considerations -// -// The package includes several security features: -// -// - Detection of malformed Forwarded headers and duplicate single-IP header values -// - Immediate proxy trust enforcement before honoring Forwarded/X-Forwarded-For -// - Validation of proxy counts (min/max enforcement) -// - Chain length limits to prevent DoS -// - Rejection of invalid/implausible IPs (loopback, multicast, etc.) -// - Optional private IP filtering and explicit reserved CIDR allowlisting -// - Strict fail-closed behavior by default (SecurityModeStrict) -// -// # Security Anti-Patterns -// -// - Do not combine multiple competing header sources for security decisions. -// - Do not use SecurityModeLax for ACL/risk/authz enforcement paths. -// - Do not trust broad proxy CIDRs unless they are truly controlled by your edge. -// -// # Security Modes -// -// Security behavior can be configured per extractor: -// -// - SecurityModeStrict (default): fail closed on security-significant errors and invalid present source values. -// - SecurityModeLax: allow fallback to lower-priority sources for those errors. -// -// Example: -// -// extractor, _ := clientip.New( -// clientip.WithSecurityMode(clientip.SecurityModeLax), -// ) +// Logger and Metrics remain separate public interfaces. // -// # Thread Safety +// Security event labels are exported as SecurityEvent... constants so adapters +// can depend on stable names. // -// Extractor instances are safe for concurrent use. They are typically created -// once at application startup and reused across all requests. +// Preferred resolver fallback remains result-only in this phase. Inspect +// Resolution.FallbackUsed rather than expecting separate fallback log or metric +// signals. package clientip diff --git a/events.go b/events.go deleted file mode 100644 index ea956bc..0000000 --- a/events.go +++ /dev/null @@ -1,14 +0,0 @@ -package clientip - -const ( - securityEventMultipleHeaders = "multiple_headers" - securityEventChainTooLong = "chain_too_long" - securityEventUntrustedProxy = "untrusted_proxy" - securityEventNoTrustedProxies = "no_trusted_proxies" - securityEventTooFewTrustedProxies = "too_few_trusted_proxies" - securityEventTooManyTrustedProxies = "too_many_trusted_proxies" - securityEventInvalidIP = "invalid_ip" - securityEventReservedIP = "reserved_ip" - securityEventPrivateIP = "private_ip" - securityEventMalformedForwarded = "malformed_forwarded" -) diff --git a/example_test.go b/example_test.go index 0b26d10..dfd4ca3 100644 --- a/example_test.go +++ b/example_test.go @@ -2,210 +2,148 @@ package clientip_test import ( "context" - "errors" "fmt" - "log/slog" "net/http" - "net/netip" "net/textproto" - "os" "github.com/abczzz13/clientip" ) -func ExampleNew_simple() { - extractor, err := clientip.New() +func ExampleResolver_ResolveStrict() { + extractor, err := clientip.New(clientip.PresetLoopbackReverseProxy()) if err != nil { panic(err) } - req := &http.Request{RemoteAddr: "8.8.4.4:12345", Header: make(http.Header)} - - ip, err := extractor.ExtractAddr(req) + resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{}) if err != nil { panic(err) } - fmt.Printf("Client IP: %s\n", ip) -} - -func ExamplePresetVMReverseProxy() { - extractor, _ := clientip.New(clientip.PresetVMReverseProxy()) - - req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} - req.Header.Set("X-Forwarded-For", "1.1.1.1") - - extraction, _ := extractor.Extract(req) - fmt.Println(extraction.IP, extraction.Source) - // Output: 1.1.1.1 x_forwarded_for -} - -func ExamplePresetPreferredHeaderThenXFFLax() { - extractor, _ := clientip.New( - clientip.WithTrustedLoopbackProxy(), - clientip.PresetPreferredHeaderThenXFFLax("X-Frontend-IP"), - ) - req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} - req.Header.Set("X-Frontend-IP", "not-an-ip") req.Header.Set("X-Forwarded-For", "8.8.8.8") - extraction, _ := extractor.Extract(req) - fmt.Println(extraction.IP, extraction.Source) - // Output: 8.8.8.8 x_forwarded_for -} - -func ExampleNew_forwarded() { - extractor, _ := clientip.New( - clientip.WithTrustedLoopbackProxy(), - clientip.WithSourcePriority(clientip.SourceForwarded, clientip.SourceRemoteAddr), - ) + req, resolution := resolver.ResolveStrict(req) + if resolution.Err != nil { + panic(resolution.Err) + } - req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} - req.Header.Set("Forwarded", "for=1.1.1.1") + fmt.Println(resolution.IP, resolution.Source, resolution.FallbackUsed) - extraction, _ := extractor.Extract(req) - fmt.Println(extraction.IP, extraction.Source) - // Output: 1.1.1.1 forwarded + cached, ok := clientip.StrictResolutionFromContext(req.Context()) + fmt.Println(ok, cached.IP) + // Output: + // 8.8.8.8 x_forwarded_for false + // true 8.8.8.8 } -func ExampleNew_withOptions() { - cidrs, _ := netip.ParsePrefix("10.0.0.0/8") - - extractor, err := clientip.New( - clientip.WithTrustedProxyPrefixes(cidrs), - clientip.WithMinTrustedProxies(1), - clientip.WithMaxTrustedProxies(2), - clientip.WithSourcePriority(clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), - clientip.WithAllowPrivateIPs(false), - clientip.WithLogger(slog.New(slog.NewTextHandler(os.Stdout, nil))), - ) +func ExampleResolver_ResolvePreferred() { + extractor, err := clientip.New(clientip.Config{ + TrustedProxyPrefixes: clientip.LoopbackProxyPrefixes(), + Sources: []clientip.Source{clientip.SourceXForwardedFor}, + }) if err != nil { panic(err) } - req := &http.Request{RemoteAddr: "10.0.1.5:12345", Header: make(http.Header)} - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.1.5") - - extraction, _ := extractor.Extract(req) - fmt.Printf("Client IP: %s from source: %s\n", extraction.IP, extraction.Source) -} - -func ExampleWithAllowedReservedClientPrefixes() { - extractor, _ := clientip.New( - clientip.WithAllowedReservedClientPrefixes(netip.MustParsePrefix("198.51.100.0/24")), - ) + resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{PreferredFallback: clientip.PreferredFallbackRemoteAddr}) + if err != nil { + panic(err) + } - req := &http.Request{RemoteAddr: "198.51.100.10:12345", Header: make(http.Header)} + req := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: make(http.Header)} - extraction, _ := extractor.Extract(req) - fmt.Println(extraction.IP, extraction.Source) - // Output: 198.51.100.10 remote_addr -} + _, resolution := resolver.ResolvePreferred(req) + if resolution.Err != nil { + panic(resolution.Err) + } -func ExampleNew_flexibleProxyRange() { - cidrs, _ := netip.ParsePrefix("10.0.0.0/8") - - extractor, _ := clientip.New( - clientip.WithTrustedProxyPrefixes(cidrs), - clientip.WithMinTrustedProxies(1), - clientip.WithMaxTrustedProxies(3), - clientip.WithSourcePriority(clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), - ) - - req1 := &http.Request{RemoteAddr: "10.0.0.1:12345", Header: make(http.Header)} - req1.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") - extraction1, _ := extractor.Extract(req1) - fmt.Printf("1 proxy: %s\n", extraction1.IP) - - req2 := &http.Request{RemoteAddr: "10.0.0.3:12345", Header: make(http.Header)} - req2.Header.Set("X-Forwarded-For", "8.8.8.8, 10.0.0.2, 10.0.0.3") - extraction2, _ := extractor.Extract(req2) - fmt.Printf("2 proxies: %s\n", extraction2.IP) + fmt.Println(resolution.IP, resolution.Source, resolution.FallbackUsed) + // Output: + // 1.1.1.1 remote_addr true } -func ExampleNew_cloudflare() { - extractor, _ := clientip.New( - clientip.WithTrustedLoopbackProxy(), - clientip.WithSourcePriority(clientip.HeaderSource("CF-Connecting-IP"), clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), - ) - - req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} - req.Header.Set("CF-Connecting-IP", "1.1.1.1") +func ExamplePreferredResolutionFromContext() { + extractor, err := clientip.New(clientip.DefaultConfig()) + if err != nil { + panic(err) + } - extraction, _ := extractor.Extract(req) - fmt.Printf("Client IP: %s (from %s)\n", extraction.IP, extraction.Source) -} + resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{}) + if err != nil { + panic(err) + } -func ExampleHeader() { - extractor, _ := clientip.New( - clientip.WithTrustedLoopbackProxy(), - clientip.WithSourcePriority(clientip.HeaderSource("X-Custom-IP"), clientip.SourceRemoteAddr), - ) + req := &http.Request{RemoteAddr: "8.8.4.4:12345", Header: make(http.Header)} - req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} - req.Header.Set("X-Custom-IP", "8.8.8.8") + req, resolution := resolver.ResolvePreferred(req) + if resolution.Err != nil { + panic(resolution.Err) + } - ip, _ := extractor.ExtractAddr(req) - fmt.Printf("IP: %s\n", ip) + cached, ok := clientip.PreferredResolutionFromContext(req.Context()) + fmt.Println(ok, cached.IP == resolution.IP, cached.Source) + // Output: + // true true remote_addr } -func ExampleWithChainSelection_leftmostUntrusted() { - cloudflareCIDRs, _ := netip.ParsePrefix("173.245.48.0/20") +func ExampleExtractor_Extract() { + extractor, err := clientip.New(clientip.DefaultConfig()) + if err != nil { + panic(err) + } - extractor, _ := clientip.New( - clientip.WithTrustedProxyPrefixes(cloudflareCIDRs), - clientip.WithMinTrustedProxies(1), - clientip.WithMaxTrustedProxies(3), - clientip.WithSourcePriority(clientip.SourceXForwardedFor, clientip.SourceRemoteAddr), - clientip.WithChainSelection(clientip.LeftmostUntrustedIP), - ) + req := &http.Request{RemoteAddr: "8.8.4.4:12345", Header: make(http.Header)} - req := &http.Request{RemoteAddr: "173.245.48.5:443", Header: make(http.Header)} - req.Header.Set("X-Forwarded-For", "1.1.1.1, 173.245.48.5") + extraction, err := extractor.Extract(req) + if err != nil { + panic(err) + } - ip, _ := extractor.ExtractAddr(req) - fmt.Printf("Client IP: %s\n", ip) + fmt.Println(extraction.IP, extraction.Source) + // Output: + // 8.8.4.4 remote_addr } -func ExampleWithSecurityMode_strict() { - extractor, _ := clientip.New( - clientip.WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - clientip.WithSourcePriority(clientip.SourceForwarded, clientip.SourceRemoteAddr), - clientip.WithSecurityMode(clientip.SecurityModeStrict), - ) - - req := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: make(http.Header)} - req.Header.Set("Forwarded", `for="1.1.1.1`) +func ExamplePresetVMReverseProxy() { + extractor, err := clientip.New(clientip.PresetVMReverseProxy()) + if err != nil { + panic(err) + } - extraction, err := extractor.Extract(req) - fmt.Println(err == nil, errors.Is(err, clientip.ErrInvalidForwardedHeader), extraction.Source) - // Output: false true forwarded -} + resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{}) + if err != nil { + panic(err) + } -func ExampleWithSecurityMode_lax() { - extractor, _ := clientip.New( - clientip.WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - clientip.WithSourcePriority(clientip.SourceForwarded, clientip.SourceRemoteAddr), - clientip.WithSecurityMode(clientip.SecurityModeLax), - ) + req := &http.Request{RemoteAddr: "127.0.0.1:12345", Header: make(http.Header)} + req.Header.Set("X-Forwarded-For", "1.1.1.1") - req := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: make(http.Header)} - req.Header.Set("Forwarded", `for="1.1.1.1`) + _, resolution := resolver.ResolveStrict(req) + if resolution.Err != nil { + panic(resolution.Err) + } - extraction, _ := extractor.Extract(req) - fmt.Println(extraction.IP, extraction.Source) - // Output: 1.1.1.1 remote_addr + fmt.Println(resolution.IP, resolution.Source) + // Output: 1.1.1.1 x_forwarded_for } -func ExampleExtractor_ExtractFrom() { - extractor, _ := clientip.New( - clientip.WithTrustedLoopbackProxy(), - clientip.WithSourcePriority(clientip.HeaderSource("CF-Connecting-IP"), clientip.SourceRemoteAddr), - ) +func ExampleResolver_ResolveInputPreferred() { + extractor, err := clientip.New(clientip.Config{ + TrustedProxyPrefixes: clientip.LoopbackProxyPrefixes(), + Sources: []clientip.Source{clientip.HeaderSource("CF-Connecting-IP"), clientip.SourceRemoteAddr}, + }) + if err != nil { + panic(err) + } + + resolver, err := clientip.NewResolver(extractor, clientip.ResolverConfig{}) + if err != nil { + panic(err) + } cfHeader := textproto.CanonicalMIMEHeaderKey("CF-Connecting-IP") - input := clientip.RequestInput{ + input := clientip.Input{ Context: context.Background(), RemoteAddr: "127.0.0.1:12345", Path: "/framework-request", @@ -217,7 +155,15 @@ func ExampleExtractor_ExtractFrom() { }), } - extraction, _ := extractor.ExtractFrom(input) - fmt.Println(extraction.IP, extraction.Source) - // Output: 8.8.8.8 cf_connecting_ip + input, resolution := resolver.ResolveInputPreferred(input) + if resolution.Err != nil { + panic(resolution.Err) + } + + cached, ok := clientip.PreferredResolutionFromContext(input.Context) + fmt.Println(resolution.IP, resolution.Source, resolution.FallbackUsed) + fmt.Println(ok, cached.Source) + // Output: + // 8.8.8.8 cf_connecting_ip false + // true cf_connecting_ip } diff --git a/extract_from_test.go b/extract_from_test.go deleted file mode 100644 index 3d643e5..0000000 --- a/extract_from_test.go +++ /dev/null @@ -1,419 +0,0 @@ -package clientip - -import ( - "context" - "errors" - "net/http" - "net/netip" - "net/textproto" - "net/url" - "testing" -) - -type extractFromContextKey string - -type panicOnNilHeaderProvider struct{} - -func (p *panicOnNilHeaderProvider) Values(string) []string { - if p == nil { - panic("nil header provider should not be called") - } - - return nil -} - -func TestExtractFrom_ParityWithExtract(t *testing.T) { - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - tests := []struct { - name string - headers []string - remote string - }{ - { - name: "remote_addr_only", - remote: "8.8.8.8:8080", - }, - { - name: "xff_success", - remote: "1.1.1.1:8080", - headers: []string{"8.8.8.8"}, - }, - { - name: "duplicate_xff", - remote: "1.1.1.1:8080", - headers: []string{"8.8.8.8", "9.9.9.9"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), extractFromContextKey("trace_id"), "trace-123") - req := (&http.Request{ - RemoteAddr: tt.remote, - Header: make(http.Header), - URL: &url.URL{Path: "/parity"}, - }).WithContext(ctx) - - for _, value := range tt.headers { - req.Header.Add("X-Forwarded-For", value) - } - - httpExtraction, httpErr := extractor.Extract(req) - - inputExtraction, inputErr := extractor.ExtractFrom(RequestInput{ - Context: req.Context(), - RemoteAddr: req.RemoteAddr, - Path: req.URL.Path, - Headers: req.Header, - }) - - if httpErr != nil { - t.Fatalf("Extract() error = %v", httpErr) - } - if inputErr != nil { - t.Fatalf("ExtractFrom() error = %v", inputErr) - } - - if inputExtraction != httpExtraction { - t.Fatalf("extraction mismatch: ExtractFrom=%+v Extract=%+v", inputExtraction, httpExtraction) - } - }) - } -} - -func TestExtractFrom_HeaderValuesFunc(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(HeaderSource("CF-Connecting-IP"), SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - cfHeader := textproto.CanonicalMIMEHeaderKey("CF-Connecting-IP") - requestedHeaders := make([]string, 0, 1) - headers := HeaderValuesFunc(func(name string) []string { - requestedHeaders = append(requestedHeaders, name) - if name == cfHeader { - return []string{"9.9.9.9"} - } - return nil - }) - - extraction, err := extractor.ExtractFrom(RequestInput{ - RemoteAddr: "127.0.0.1:8080", - Headers: headers, - }) - if err != nil { - t.Fatalf("ExtractFrom() error = %v", err) - } - - if got, want := extraction.Source, HeaderSource("CF-Connecting-IP"); got != want { - t.Fatalf("source = %q, want %q", got, want) - } - if got, want := extraction.IP, netip.MustParseAddr("9.9.9.9"); got != want { - t.Fatalf("ip = %s, want %s", got, want) - } - if len(requestedHeaders) != 1 || requestedHeaders[0] != cfHeader { - t.Fatalf("requested headers = %v, want [%q]", requestedHeaders, cfHeader) - } -} - -func TestExtractFrom_RemoteAddrOnlyDoesNotRequestHeaders(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } - - requested := 0 - input := RequestInput{ - RemoteAddr: "8.8.8.8:8080", - Headers: HeaderValuesFunc(func(name string) []string { - requested++ - return nil - }), - } - - extraction, err := extractor.ExtractFrom(input) - if err != nil { - t.Fatalf("ExtractFrom() error = %v", err) - } - if got, want := extraction.IP.String(), "8.8.8.8"; got != want { - t.Fatalf("ip = %q, want %q", got, want) - } - if requested != 0 { - t.Fatalf("header provider called %d times, want 0", requested) - } -} - -func TestExtractFrom_CallOptionSourcePriority_UsesEffectiveHeaders(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - t.Run("custom header override requests only effective header", func(t *testing.T) { - cfHeader := textproto.CanonicalMIMEHeaderKey("CF-Connecting-IP") - requestedHeaders := make([]string, 0, 1) - - extraction, err := extractor.ExtractFrom( - RequestInput{ - RemoteAddr: "127.0.0.1:8080", - Headers: HeaderValuesFunc(func(name string) []string { - requestedHeaders = append(requestedHeaders, name) - switch name { - case cfHeader: - return []string{"9.9.9.9"} - case "X-Forwarded-For": - return []string{"8.8.8.8"} - default: - return nil - } - }), - }, - WithCallSourcePriority(HeaderSource("cf-connecting-ip"), SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("ExtractFrom() error = %v", err) - } - - if got, want := extraction.IP, netip.MustParseAddr("9.9.9.9"); got != want { - t.Fatalf("ip = %s, want %s", got, want) - } - if got, want := extraction.Source, HeaderSource("CF-Connecting-IP"); got != want { - t.Fatalf("source = %q, want %q", got, want) - } - if len(requestedHeaders) != 1 || requestedHeaders[0] != cfHeader { - t.Fatalf("requested headers = %v, want [%q]", requestedHeaders, cfHeader) - } - }) - - t.Run("remote addr override skips header provider", func(t *testing.T) { - requested := 0 - - extraction, err := extractor.ExtractFrom( - RequestInput{ - RemoteAddr: "8.8.8.8:8080", - Headers: HeaderValuesFunc(func(name string) []string { - requested++ - return []string{"9.9.9.9"} - }), - }, - WithCallSourcePriority(SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("ExtractFrom() error = %v", err) - } - - if got, want := extraction.IP, netip.MustParseAddr("8.8.8.8"); got != want { - t.Fatalf("ip = %s, want %s", got, want) - } - if got, want := extraction.Source, SourceRemoteAddr; got != want { - t.Fatalf("source = %q, want %q", got, want) - } - if requested != 0 { - t.Fatalf("header provider called %d times, want 0", requested) - } - }) -} - -func TestExtractFrom_TypedNilHeaderProviderTreatedAsAbsent(t *testing.T) { - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("8.8.8.8")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - var nilHTTPHeader *http.Header - var nilProvider *panicOnNilHeaderProvider - - tests := []struct { - name string - headers HeaderValues - }{ - {name: "typed_nil_http_header_pointer", headers: nilHTTPHeader}, - {name: "typed_nil_custom_provider_pointer", headers: nilProvider}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extraction, extractErr := extractor.ExtractFrom(RequestInput{ - RemoteAddr: "8.8.8.8:8080", - Headers: tt.headers, - }) - if extractErr != nil { - t.Fatalf("ExtractFrom() error = %v", extractErr) - } - - if got, want := extraction.Source, SourceRemoteAddr; got != want { - t.Fatalf("source = %q, want %q", got, want) - } - if got, want := extraction.IP, netip.MustParseAddr("8.8.8.8"); got != want { - t.Fatalf("ip = %s, want %s", got, want) - } - }) - } -} - -func TestExtractFrom_RemoteAddrOnly_RespectsCanceledContext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - t.Run("default_remote_addr_priority", func(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } - - _, extractErr := extractor.ExtractFrom(RequestInput{ - Context: ctx, - RemoteAddr: "8.8.8.8:8080", - }) - if !errors.Is(extractErr, context.Canceled) { - t.Fatalf("error = %v, want context.Canceled", extractErr) - } - }) - - t.Run("call_option_to_remote_addr_priority", func(t *testing.T) { - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("8.8.8.8")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - _, extractErr := extractor.ExtractFrom( - RequestInput{ - Context: ctx, - RemoteAddr: "8.8.8.8:8080", - }, - WithCallSourcePriority(SourceRemoteAddr), - ) - if !errors.Is(extractErr, context.Canceled) { - t.Fatalf("error = %v, want context.Canceled", extractErr) - } - }) -} - -func TestExtractFrom_CanceledContext_DoesNotRequestHeaders(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - requested := 0 - _, extractErr := extractor.ExtractFrom(RequestInput{ - Context: ctx, - RemoteAddr: "1.1.1.1:8080", - Headers: HeaderValuesFunc(func(name string) []string { - requested++ - return []string{"8.8.8.8"} - }), - }) - if !errors.Is(extractErr, context.Canceled) { - t.Fatalf("error = %v, want context.Canceled", extractErr) - } - if requested != 0 { - t.Fatalf("header provider called %d times, want 0", requested) - } -} - -func TestExtractFrom_UsesInputContextAndPathInLogs(t *testing.T) { - logger := &capturedLogger{} - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor), - WithMaxChainLength(1), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - ctx := context.WithValue(context.Background(), loggerTestContextKey("trace_id"), "trace-from-input") - headers := HeaderValuesFunc(func(name string) []string { - if name == "X-Forwarded-For" { - return []string{"8.8.8.8", "9.9.9.9"} - } - return nil - }) - - result, err := extractor.ExtractFrom(RequestInput{ - Context: ctx, - RemoteAddr: "1.1.1.1:8080", - Path: "/from-input", - Headers: headers, - }) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction failure for overlong X-Forwarded-For chain") - } - if !errors.Is(err, ErrChainTooLong) { - t.Fatalf("error = %v, want ErrChainTooLong", err) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - if got := entry.ctx.Value(loggerTestContextKey("trace_id")); got != "trace-from-input" { - t.Fatalf("trace context value = %v, want %q", got, "trace-from-input") - } - - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventChainTooLong, - SourceXForwardedFor, - "/from-input", - "1.1.1.1:8080", - ) -} - -func TestExtractFrom_NilContextDefaultsBackground(t *testing.T) { - input := RequestInput{RemoteAddr: "8.8.8.8:8080"} - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } - - extraction, err := extractor.ExtractFrom(input) - if err != nil { - t.Fatalf("ExtractFrom() error = %v", err) - } - if got, want := extraction.IP.String(), "8.8.8.8"; got != want { - t.Fatalf("IP = %q, want %q", got, want) - } - if got, want := extraction.Source, SourceRemoteAddr; got != want { - t.Fatalf("Source = %q, want %q", got, want) - } - - addr, err := extractor.ExtractAddrFrom(input) - if err != nil { - t.Fatalf("ExtractAddrFrom() error = %v", err) - } - if got, want := addr.String(), "8.8.8.8"; got != want { - t.Fatalf("IP = %q, want %q", got, want) - } -} diff --git a/extractor.go b/extractor.go index dffc6c6..5ca0fb8 100644 --- a/extractor.go +++ b/extractor.go @@ -1,7 +1,6 @@ package clientip import ( - "context" "errors" "fmt" "net/http" @@ -13,72 +12,56 @@ import ( // // Extractor instances are safe for concurrent reuse. type Extractor struct { - config *config - source sourceExtractor + config *config + source sourceExtractor + clientIP clientIPPolicy + proxy proxyPolicy } -// New creates an Extractor from one or more Option builders. -func New(opts ...Option) (*Extractor, error) { - cfg, err := configFromOptions(opts...) +// New creates an Extractor from a Config. +func New(public Config) (*Extractor, error) { + cfg, err := configFromPublic(public) if err != nil { return nil, fmt.Errorf("invalid configuration: %w", err) } - extractor := &Extractor{config: cfg} + extractor := &Extractor{ + config: cfg, + clientIP: clientIPPolicy{ + AllowPrivateIPs: cfg.allowPrivateIPs, + AllowReservedClientPrefixes: cfg.allowReservedClientPrefixes, + }, + proxy: proxyPolicy{ + TrustedProxyCIDRs: cfg.trustedProxyCIDRs, + TrustedProxyMatch: cfg.trustedProxyMatch, + MinTrustedProxies: cfg.minTrustedProxies, + MaxTrustedProxies: cfg.maxTrustedProxies, + }, + } extractor.source = extractor.buildSourceChain(cfg) return extractor, nil } -func (e *Extractor) buildSourceChain(cfg *config) sourceExtractor { - sources := make([]sourceExtractor, 0, len(cfg.sourcePriority)) - for _, configuredSource := range cfg.sourcePriority { - var source sourceExtractor - switch configuredSource.kind { - case sourceForwarded: - source = newForwardedSource(e) - case sourceXForwardedFor: - source = newForwardedForSource(e) - case sourceXRealIP: - source = newSingleHeaderSource(e, "X-Real-IP") - case sourceRemoteAddr: - source = newRemoteAddrSource(e) - default: - headerName, _ := configuredSource.headerKey() - source = newSingleHeaderSource(e, headerName) - } - sources = append(sources, source) - } - - return newChainedSource(e, sources...) -} - // Extract resolves client IP and metadata for the request. -// -// When call options are provided, they are applied left-to-right and applied only -// for this call. -func (e *Extractor) Extract(r *http.Request, callOpts ...CallOption) (Extraction, error) { +func (e *Extractor) Extract(r *http.Request) (Extraction, error) { if r == nil { return Extraction{}, ErrNilRequest } - ctx := r.Context() - - if len(callOpts) == 0 { - return e.extractWithSource(e.source, ctx, r) - } - - activeExtractor, activeSource, err := e.prepareCall(callOpts...) - if err != nil { - return Extraction{}, err + if len(e.config.sourceHeaderKeys) == 0 { + if ctx := r.Context(); ctx.Err() != nil { + return Extraction{}, ctx.Err() + } + return e.extractFromRemoteAddr(r.RemoteAddr) } - return activeExtractor.extractWithSource(activeSource, ctx, r) + return e.extractWithSource(e.source, requestViewFromRequest(r)) } // ExtractAddr resolves only the client IP address. -func (e *Extractor) ExtractAddr(r *http.Request, callOpts ...CallOption) (netip.Addr, error) { - extraction, err := e.Extract(r, callOpts...) +func (e *Extractor) ExtractAddr(r *http.Request) (netip.Addr, error) { + extraction, err := e.Extract(r) if err != nil { return netip.Addr{}, err } @@ -86,41 +69,25 @@ func (e *Extractor) ExtractAddr(r *http.Request, callOpts ...CallOption) (netip. return extraction.IP, nil } -// ExtractFrom resolves client IP and metadata from framework-agnostic request +// ExtractInput resolves client IP and metadata from framework-agnostic request // input. -// -// When call options are provided, they are applied left-to-right and applied only -// for this call. -func (e *Extractor) ExtractFrom(input RequestInput, callOpts ...CallOption) (Extraction, error) { - activeExtractor := e - activeSource := e.source - - if len(callOpts) > 0 { - var err error - activeExtractor, activeSource, err = e.prepareCall(callOpts...) - if err != nil { - return Extraction{}, err - } - } - +func (e *Extractor) ExtractInput(input Input) (Extraction, error) { ctx := requestInputContext(input) if err := ctx.Err(); err != nil { return Extraction{}, err } - if len(activeExtractor.config.sourceHeaderKeys) == 0 { - return activeExtractor.extractFromRemoteAddr(input.RemoteAddr) + if len(e.config.sourceHeaderKeys) == 0 { + return e.extractFromRemoteAddr(input.RemoteAddr) } - req := requestFromInput(input, activeExtractor.config.sourceHeaderKeys) - - return activeExtractor.extractWithSource(activeSource, ctx, req) + return e.extractWithSource(e.source, requestViewFromInput(input)) } -// ExtractAddrFrom resolves only the client IP address from framework-agnostic +// ExtractInputAddr resolves only the client IP address from framework-agnostic // request input. -func (e *Extractor) ExtractAddrFrom(input RequestInput, callOpts ...CallOption) (netip.Addr, error) { - extraction, err := e.ExtractFrom(input, callOpts...) +func (e *Extractor) ExtractInputAddr(input Input) (netip.Addr, error) { + extraction, err := e.ExtractInput(input) if err != nil { return netip.Addr{}, err } @@ -128,73 +95,45 @@ func (e *Extractor) ExtractAddrFrom(input RequestInput, callOpts ...CallOption) return extraction.IP, nil } -func (e *Extractor) prepareCall(callOpts ...CallOption) (*Extractor, sourceExtractor, error) { - activeExtractor := e - activeSource := e.source - - if len(callOpts) == 0 { - return activeExtractor, activeSource, nil - } - - effectiveConfig, err := e.config.withCallOptions(callOpts...) - if err != nil { - return nil, nil, fmt.Errorf("invalid call options: %w", err) - } - if effectiveConfig == e.config { - return activeExtractor, activeSource, nil +func (e *Extractor) extractWithSource(source sourceExtractor, r requestView) (Extraction, error) { + if err := r.context().Err(); err != nil { + return Extraction{}, err } - activeExtractor = &Extractor{config: effectiveConfig} - activeExtractor.source = activeExtractor.buildSourceChain(effectiveConfig) - activeSource = activeExtractor.source - - return activeExtractor, activeSource, nil -} - -func (e *Extractor) extractWithSource(source sourceExtractor, ctx context.Context, r *http.Request) (Extraction, error) { - extractionResult, err := source.Extract(ctx, r) + result, err := source.extract(r) if err != nil { - sourceValue := e.getSource(extractionResult, err) - return Extraction{ - Source: sourceValue, - TrustedProxyCount: extractionResult.TrustedProxyCount, - DebugInfo: extractionResult.DebugInfo, - }, err - } - - sourceValue := extractionResult.Source - if !sourceValue.valid() { - sourceValue = source.Source() + fallbackSource := source.sourceInfo() + if !result.Source.valid() { + result.Source = fallbackSource + } + result.Source = e.getSource(result, err) + return result, err } - return Extraction{ - IP: normalizeIP(extractionResult.IP), - Source: sourceValue, - TrustedProxyCount: extractionResult.TrustedProxyCount, - DebugInfo: extractionResult.DebugInfo, - }, nil + return result, nil } func (e *Extractor) extractFromRemoteAddr(remoteAddr string) (Extraction, error) { - result, err := e.extractRemoteAddr(remoteAddr) - if err != nil { - sourceValue := e.getSource(result, err) - return Extraction{ - Source: sourceValue, - TrustedProxyCount: result.TrustedProxyCount, - DebugInfo: result.DebugInfo, - }, err + source := builtinSource(sourceRemoteAddr) + result, failure := remoteAddrExtractor{clientIPPolicy: e.clientIP}.extract(remoteAddr, source) + if failure != nil { + if failure.kind != failureSourceUnavailable { + e.recordInvalidClientIPDisposition(failure.clientIPDisposition) + e.config.metrics.RecordExtractionFailure(source.String()) + } + err := adaptRemoteAddrFailure(failure, source) + result.Source = e.getSource(result, err) + return result, err } - return Extraction{ - IP: normalizeIP(result.IP), - Source: result.Source, - TrustedProxyCount: result.TrustedProxyCount, - DebugInfo: result.DebugInfo, - }, nil + e.config.metrics.RecordExtractionSuccess(source.String()) + return result, nil } -func (e *Extractor) getSource(result extractionResult, err error) Source { +// getSource resolves the authoritative source for a result. +// +// Precedence: error-embedded source > result source > extractor default. +func (e *Extractor) getSource(result Extraction, err error) Source { if err != nil { var sourceErr interface{ SourceValue() Source } if errors.As(err, &sourceErr) { @@ -205,5 +144,5 @@ func (e *Extractor) getSource(result extractionResult, err error) Source { if result.Source.valid() { return result.Source } - return e.source.Source() + return e.source.sourceInfo() } diff --git a/extractor_test.go b/extractor_test.go index e63a384..535a0eb 100644 --- a/extractor_test.go +++ b/extractor_test.go @@ -5,13 +5,15 @@ import ( "errors" "net/http" "net/netip" + "net/textproto" + "net/url" "testing" "github.com/google/go-cmp/cmp" ) func TestExtract_RemoteAddr(t *testing.T) { - extractor, err := New() + extractor, err := New(DefaultConfig()) if err != nil { t.Fatalf("New() error = %v", err) } @@ -125,13 +127,10 @@ func TestExtract_RemoteAddr(t *testing.T) { func TestExtract_WithSourcePriorityAlias_CanonicalizesBuiltIns(t *testing.T) { t.Run("Forwarded alias uses Forwarded parser", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(HeaderSource("Forwarded"), SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{HeaderSource("Forwarded"), SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", @@ -153,13 +152,10 @@ func TestExtract_WithSourcePriorityAlias_CanonicalizesBuiltIns(t *testing.T) { }) t.Run("X-Forwarded-For alias combines multiple header lines", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(HeaderSource("X-Forwarded-For"), SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{HeaderSource("X-Forwarded-For"), SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", @@ -181,10 +177,9 @@ func TestExtract_WithSourcePriorityAlias_CanonicalizesBuiltIns(t *testing.T) { }) t.Run("Remote-Addr alias maps to RemoteAddr source", func(t *testing.T) { - extractor, err := New(WithSourcePriority(HeaderSource("Remote-Addr"))) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.Sources = []Source{HeaderSource("Remote-Addr")} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "8.8.8.8:8080", @@ -204,13 +199,10 @@ func TestExtract_WithSourcePriorityAlias_CanonicalizesBuiltIns(t *testing.T) { }) t.Run("X_Real_IP alias maps to X-Real-IP source", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(HeaderSource("X_Real_IP"), SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{HeaderSource("X_Real_IP"), SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", @@ -232,14 +224,10 @@ func TestExtract_WithSourcePriorityAlias_CanonicalizesBuiltIns(t *testing.T) { } func TestExtract_XForwardedFor(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mergeUniquePrefixes(LoopbackProxyPrefixes(), mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1"))...) + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) tests := []struct { name string @@ -326,13 +314,10 @@ func TestExtract_XForwardedFor(t *testing.T) { } func TestExtract_Forwarded(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceForwarded, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) tests := []struct { name string @@ -425,15 +410,12 @@ func TestExtract_Forwarded_WithTrustedProxies(t *testing.T) { t.Fatalf("ParseCIDRs() error = %v", err) } - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(2), - WithSourcePriority(SourceForwarded), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 2 + cfg.Sources = []Source{SourceForwarded} + extractor := mustNewExtractor(t, cfg) tests := []struct { name string @@ -512,60 +494,33 @@ func TestExtract_Forwarded_WithTrustedProxies(t *testing.T) { } } -func TestExtract_ParsesMultipleXFFHeaders_AcrossSecurityModes(t *testing.T) { - tests := []struct { - name string - securityMode SecurityMode - setSecurityMode bool - }{ - {name: "strict_default"}, - {name: "lax", securityMode: SecurityModeLax, setSecurityMode: true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - opts := []Option{ - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - } - if tt.setSecurityMode { - opts = append(opts, WithSecurityMode(tt.securityMode)) - } - - extractor, err := New(opts...) - if err != nil { - t.Fatalf("New() error = %v", err) - } +func TestExtract_ParsesMultipleXFFHeaders(t *testing.T) { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) - req := &http.Request{ - RemoteAddr: "1.1.1.1:8080", - Header: make(http.Header), - } - req.Header.Add("X-Forwarded-For", "8.8.8.8") - req.Header.Add("X-Forwarded-For", "9.9.9.9") + req := &http.Request{RemoteAddr: "1.1.1.1:8080", Header: make(http.Header)} + req.Header.Add("X-Forwarded-For", "8.8.8.8") + req.Header.Add("X-Forwarded-For", "9.9.9.9") - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("expected extraction to succeed, got error: %v", err) - } - if result.Source != SourceXForwardedFor { - t.Fatalf("source = %q, want %q", result.Source, SourceXForwardedFor) - } - if got, want := result.IP.String(), "9.9.9.9"; got != want { - t.Fatalf("ip = %q, want %q", got, want) - } - }) + result, err := extractor.Extract(req) + if err != nil || !result.IP.IsValid() { + t.Fatalf("expected extraction to succeed, got error: %v", err) + } + if result.Source != SourceXForwardedFor { + t.Fatalf("source = %q, want %q", result.Source, SourceXForwardedFor) + } + if got, want := result.IP.String(), "9.9.9.9"; got != want { + t.Fatalf("ip = %q, want %q", got, want) } } func TestExtract_StrictMode_MalformedForwarded_IsTerminal(t *testing.T) { - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceForwarded, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "1.1.1.1:8080", @@ -586,44 +541,12 @@ func TestExtract_StrictMode_MalformedForwarded_IsTerminal(t *testing.T) { } } -func TestExtract_SecurityModeLax_AllowsFallbackOnMalformedForwarded(t *testing.T) { - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceForwarded, SourceRemoteAddr), - WithSecurityMode(SecurityModeLax), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{ - RemoteAddr: "1.1.1.1:8080", - Header: make(http.Header), - } - req.Header.Set("Forwarded", "for=\"1.1.1.1") - req.Header.Set("X-Forwarded-For", "8.8.8.8") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("expected SecurityModeLax fallback success, got error: %v", err) - } - if result.Source != SourceRemoteAddr { - t.Fatalf("source = %q, want %q", result.Source, SourceRemoteAddr) - } - if got, want := result.IP.String(), "1.1.1.1"; got != want { - t.Fatalf("ip = %q, want %q", got, want) - } -} - func TestExtract_ChainTooLong_IsTerminalByDefault(t *testing.T) { - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithMaxChainLength(2), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + cfg.MaxChainLength = 2 + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "1.1.1.1:8080", @@ -649,15 +572,12 @@ func TestExtract_StrictMode_UntrustedProxy_IsTerminal(t *testing.T) { t.Fatalf("ParseCIDRs() error = %v", err) } - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 3 + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "8.8.8.8:8080", @@ -677,115 +597,11 @@ func TestExtract_StrictMode_UntrustedProxy_IsTerminal(t *testing.T) { } } -func TestExtract_SecurityModeLax_AllowsFallbackOnUntrustedProxy(t *testing.T) { - cidrs, err := ParseCIDRs("10.0.0.0/8") - if err != nil { - t.Fatalf("ParseCIDRs() error = %v", err) - } - - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeLax), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{ - RemoteAddr: "8.8.8.8:8080", - Header: make(http.Header), - } - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("expected SecurityModeLax fallback success, got error: %v", err) - } - if result.Source != SourceRemoteAddr { - t.Fatalf("source = %q, want %q", result.Source, SourceRemoteAddr) - } - if got, want := result.IP.String(), "8.8.8.8"; got != want { - t.Fatalf("ip = %q, want %q", got, want) - } -} - -func TestExtract_SingleHeader_UntrustedProxy_StrictVsLax(t *testing.T) { - cidrs, err := ParseCIDRs("10.0.0.0/8") - if err != nil { - t.Fatalf("ParseCIDRs() error = %v", err) - } - - tests := []struct { - name string - mode SecurityMode - wantErr error - want extractionState - }{ - { - name: "strict mode fails closed", - mode: SecurityModeStrict, - wantErr: ErrUntrustedProxy, - want: extractionState{HasIP: false, Source: SourceXRealIP}, - }, - { - name: "lax mode falls back to remote addr", - mode: SecurityModeLax, - want: extractionState{HasIP: true, IP: "8.8.8.8", Source: SourceRemoteAddr}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor, newErr := New( - WithTrustedProxyPrefixes(cidrs...), - WithSourcePriority(SourceXRealIP, SourceRemoteAddr), - WithSecurityMode(tt.mode), - ) - if newErr != nil { - t.Fatalf("New() error = %v", newErr) - } - - req := &http.Request{ - RemoteAddr: "8.8.8.8:8080", - Header: make(http.Header), - } - req.Header.Set("X-Real-IP", "1.1.1.1") - - got, extractErr := extractor.Extract(req) - - if tt.wantErr != nil { - if !errors.Is(extractErr, tt.wantErr) { - t.Fatalf("error = %v, want %v", extractErr, tt.wantErr) - } - } else if extractErr != nil { - t.Fatalf("Extract() error = %v", extractErr) - } - - gotView := extractionStateOf(got) - - if diff := cmp.Diff(tt.want, gotView); diff != "" { - t.Fatalf("extraction mismatch (-want +got):\n%s", diff) - } - }) - } -} - func TestExtract_XRealIP(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority( - SourceXRealIP, - SourceXForwardedFor, - SourceRemoteAddr, - ), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mergeUniquePrefixes(LoopbackProxyPrefixes(), mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1"))...) + cfg.Sources = []Source{SourceXRealIP, SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) tests := []struct { name string @@ -857,44 +673,11 @@ func TestExtract_XRealIP(t *testing.T) { } } -func TestExtract_SecurityModeLax_AllowsFallbackOnInvalidPreferredSingleHeader(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXRealIP, SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeLax), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{ - RemoteAddr: "127.0.0.1:8080", - Header: make(http.Header), - } - req.Header.Set("X-Real-IP", "not-an-ip") - req.Header.Set("X-Forwarded-For", "8.8.8.8") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("expected lax-mode fallback to succeed, got error: %v", err) - } - if result.Source != SourceXForwardedFor { - t.Fatalf("source = %q, want %q", result.Source, SourceXForwardedFor) - } - if got, want := result.IP.String(), "8.8.8.8"; got != want { - t.Fatalf("ip = %q, want %q", got, want) - } -} - func TestExtract_StrictMode_DuplicatePreferredSingleHeader_IsTerminal(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXRealIP, SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeStrict), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXRealIP, SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", @@ -916,51 +699,18 @@ func TestExtract_StrictMode_DuplicatePreferredSingleHeader_IsTerminal(t *testing } } -func TestExtract_SecurityModeLax_AllowsFallbackOnDuplicatePreferredSingleHeader(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXRealIP, SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeLax), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{ - RemoteAddr: "127.0.0.1:8080", - Header: make(http.Header), - } - req.Header.Add("X-Real-IP", "9.9.9.9") - req.Header.Add("X-Real-IP", "8.8.8.8") - req.Header.Set("X-Forwarded-For", "1.1.1.1") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("expected lax-mode fallback to succeed, got error: %v", err) - } - if result.Source != SourceXForwardedFor { - t.Fatalf("source = %q, want %q", result.Source, SourceXForwardedFor) - } - if got, want := result.IP.String(), "1.1.1.1"; got != want { - t.Fatalf("ip = %q, want %q", got, want) - } -} - func TestExtract_WithTrustedProxies(t *testing.T) { cidrs, err := ParseCIDRs("10.0.0.0/8") if err != nil { t.Fatalf("ParseCIDRs() error = %v", err) } - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(2), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 2 + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) tests := []struct { name string @@ -1049,15 +799,12 @@ func TestExtract_WithTrustedProxies_MinZero_AllowsClientOnlyXFF(t *testing.T) { t.Fatalf("ParseCIDRs() error = %v", err) } - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(0), - WithMaxTrustedProxies(2), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 0 + cfg.MaxTrustedProxies = 2 + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "10.0.0.1:8080", @@ -1081,12 +828,9 @@ func TestExtract_WithTrustedProxies_MinZero_AllowsClientOnlyXFF(t *testing.T) { } func TestExtract_WithAllowPrivateIPs(t *testing.T) { - extractor, err := New( - WithAllowPrivateIPs(true), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.AllowPrivateIPs = true + extractor := mustNewExtractor(t, cfg) tests := []struct { name string @@ -1144,10 +888,9 @@ func TestExtract_WithAllowPrivateIPs(t *testing.T) { func TestExtract_WithAllowedReservedClientPrefixes(t *testing.T) { t.Run("remote reserved allowed", func(t *testing.T) { - extractor, err := New(WithAllowedReservedClientPrefixes(netip.MustParsePrefix("198.51.100.0/24"))) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.AllowedReservedClientPrefixes = []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + extractor := mustNewExtractor(t, cfg) req := &http.Request{RemoteAddr: "198.51.100.10:8080", Header: make(http.Header)} result, err := extractor.Extract(req) @@ -1164,27 +907,23 @@ func TestExtract_WithAllowedReservedClientPrefixes(t *testing.T) { }) t.Run("remote reserved not allowlisted", func(t *testing.T) { - extractor, err := New(WithAllowedReservedClientPrefixes(netip.MustParsePrefix("203.0.113.0/24"))) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.AllowedReservedClientPrefixes = []netip.Prefix{netip.MustParsePrefix("203.0.113.0/24")} + extractor := mustNewExtractor(t, cfg) req := &http.Request{RemoteAddr: "198.51.100.10:8080", Header: make(http.Header)} - _, err = extractor.Extract(req) + _, err := extractor.Extract(req) if !errors.Is(err, ErrInvalidIP) { t.Fatalf("error = %v, want ErrInvalidIP", err) } }) t.Run("xff reserved allowed", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor), - WithAllowedReservedClientPrefixes(netip.MustParsePrefix("100.64.0.0/10")), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor} + cfg.AllowedReservedClientPrefixes = []netip.Prefix{netip.MustParsePrefix("100.64.0.0/10")} + extractor := mustNewExtractor(t, cfg) req := &http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)} req.Header.Set("X-Forwarded-For", "100.64.0.1") @@ -1203,196 +942,27 @@ func TestExtract_WithAllowedReservedClientPrefixes(t *testing.T) { }) t.Run("private still rejected", func(t *testing.T) { - extractor, err := New(WithAllowedReservedClientPrefixes(netip.MustParsePrefix("192.168.0.0/16"))) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{RemoteAddr: "192.168.1.10:8080", Header: make(http.Header)} - _, err = extractor.Extract(req) - if !errors.Is(err, ErrInvalidIP) { - t.Fatalf("error = %v, want ErrInvalidIP", err) - } - }) -} - -func TestExtract_WithCallOptions_AllowedReservedClientPrefixes(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{RemoteAddr: "198.51.100.10:8080", Header: make(http.Header)} - - _, err = extractor.Extract(req) - if !errors.Is(err, ErrInvalidIP) { - t.Fatalf("error = %v, want ErrInvalidIP", err) - } - - result, err := extractor.Extract(req, WithCallAllowedReservedClientPrefixes(netip.MustParsePrefix("198.51.100.0/24"))) - if err != nil { - t.Fatalf("Extract() with call option error = %v", err) - } - - if got, want := result.IP, netip.MustParseAddr("198.51.100.10"); got != want { - t.Fatalf("IP = %s, want %s", got, want) - } - if got, want := result.Source, SourceRemoteAddr; got != want { - t.Fatalf("Source = %q, want %q", got, want) - } -} - -func TestExtract_WithCallOptions_RuntimeOverrides(t *testing.T) { - t.Run("trusted proxy prefixes override trust decision", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{RemoteAddr: "10.0.0.1:8080", Header: make(http.Header)} - req.Header.Set("X-Forwarded-For", "8.8.8.8") - - _, err = extractor.Extract(req) - if !errors.Is(err, ErrUntrustedProxy) { - t.Fatalf("error = %v, want ErrUntrustedProxy", err) - } - - result, err := extractor.Extract(req, - WithCallTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8")), - ) - if err != nil { - t.Fatalf("Extract() with call option error = %v", err) - } - - if got, want := result.IP, netip.MustParseAddr("8.8.8.8"); got != want { - t.Fatalf("IP = %s, want %s", got, want) - } - if got, want := result.Source, SourceXForwardedFor; got != want { - t.Fatalf("Source = %q, want %q", got, want) - } - }) - - t.Run("proxy count bounds override validation", func(t *testing.T) { - extractor, err := New( - WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8")), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{RemoteAddr: "10.0.0.2:8080", Header: make(http.Header)} - req.Header.Set("X-Forwarded-For", "8.8.8.8, 10.0.0.1, 10.0.0.2") - - _, err = extractor.Extract(req, WithCallMinTrustedProxies(3)) - if !errors.Is(err, ErrTooFewTrustedProxies) { - t.Fatalf("min-trusted error = %v, want ErrTooFewTrustedProxies", err) - } - - _, err = extractor.Extract(req, WithCallMaxTrustedProxies(1)) - if !errors.Is(err, ErrTooManyTrustedProxies) { - t.Fatalf("max-trusted error = %v, want ErrTooManyTrustedProxies", err) - } - }) - - t.Run("max chain length override applies at extraction time", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)} - req.Header.Set("X-Forwarded-For", "8.8.8.8, 9.9.9.9") - - _, err = extractor.Extract(req, WithCallMaxChainLength(1)) - if !errors.Is(err, ErrChainTooLong) { - t.Fatalf("error = %v, want ErrChainTooLong", err) - } - }) - - t.Run("chain selection and debug info can both be overridden", func(t *testing.T) { - extractor, err := New( - WithTrustedProxyPrefixes(netip.MustParsePrefix("173.245.48.0/20")), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{RemoteAddr: "173.245.48.5:443", Header: make(http.Header)} - req.Header.Set("X-Forwarded-For", "1.1.1.1, 8.8.8.8, 173.245.48.5") - - result, err := extractor.Extract(req) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - if got, want := result.IP.String(), "8.8.8.8"; got != want { - t.Fatalf("default IP = %q, want %q", got, want) - } - if result.DebugInfo != nil { - t.Fatal("default extraction should not include debug info") - } - - result, err = extractor.Extract(req, - WithCallChainSelection(LeftmostUntrustedIP), - WithCallDebugInfo(true), - ) - if err != nil { - t.Fatalf("Extract() with call options error = %v", err) - } - if got, want := result.IP.String(), "1.1.1.1"; got != want { - t.Fatalf("override IP = %q, want %q", got, want) - } - if result.DebugInfo == nil { - t.Fatal("override extraction should include debug info") - } - }) - - t.Run("allow private IPs can be enabled per call", func(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.AllowedReservedClientPrefixes = []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")} + extractor := mustNewExtractor(t, cfg) req := &http.Request{RemoteAddr: "192.168.1.10:8080", Header: make(http.Header)} - - _, err = extractor.Extract(req) + _, err := extractor.Extract(req) if !errors.Is(err, ErrInvalidIP) { t.Fatalf("error = %v, want ErrInvalidIP", err) } - - result, err := extractor.Extract(req, WithCallAllowPrivateIPs(true)) - if err != nil { - t.Fatalf("Extract() with call option error = %v", err) - } - - if got, want := result.IP, netip.MustParseAddr("192.168.1.10"); got != want { - t.Fatalf("IP = %s, want %s", got, want) - } - if got, want := result.Source, SourceRemoteAddr; got != want { - t.Fatalf("Source = %q, want %q", got, want) - } }) } func TestExtract_WithDebugInfo(t *testing.T) { cidrs, _ := ParseCIDRs("10.0.0.0/8") - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(2), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithDebugInfo(true), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 2 + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + cfg.DebugInfo = true + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "10.0.0.1:8080", @@ -1425,13 +995,10 @@ func TestExtract_WithDebugInfo(t *testing.T) { func TestExtract_ErrorTypes(t *testing.T) { t.Run("InvalidForwardedHeader", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceForwarded), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceForwarded} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", Header: make(http.Header), @@ -1455,13 +1022,10 @@ func TestExtract_ErrorTypes(t *testing.T) { }) t.Run("MultipleXFFHeaders_AreCombined", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", Header: make(http.Header), @@ -1485,13 +1049,10 @@ func TestExtract_ErrorTypes(t *testing.T) { }) t.Run("MultipleSingleIPHeadersError", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXRealIP), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXRealIP} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", Header: make(http.Header), @@ -1512,8 +1073,8 @@ func TestExtract_ErrorTypes(t *testing.T) { if multipleHeadersErr.HeaderCount != 2 { t.Errorf("HeaderCount = %d, want 2", multipleHeadersErr.HeaderCount) } - if multipleHeadersErr.HeaderName != "X-Real-IP" { - t.Errorf("HeaderName = %q, want %q", multipleHeadersErr.HeaderName, "X-Real-IP") + if multipleHeadersErr.HeaderName != "X-Real-Ip" { + t.Errorf("HeaderName = %q, want %q", multipleHeadersErr.HeaderName, "X-Real-Ip") } } @@ -1524,15 +1085,12 @@ func TestExtract_ErrorTypes(t *testing.T) { t.Run("ProxyValidationError", func(t *testing.T) { cidrs, _ := ParseCIDRs("10.0.0.0/8") - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(2), - WithMaxTrustedProxies(3), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 2 + cfg.MaxTrustedProxies = 3 + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "10.0.0.1:8080", @@ -1540,7 +1098,7 @@ func TestExtract_ErrorTypes(t *testing.T) { } req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") - _, err = extractor.Extract(req) + _, err := extractor.Extract(req) if !errors.Is(err, ErrTooFewTrustedProxies) { t.Errorf("Expected error to wrap ErrTooFewTrustedProxies, got %v", err) } @@ -1559,20 +1117,17 @@ func TestExtract_ErrorTypes(t *testing.T) { }) t.Run("InvalidIPError", func(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "127.0.0.1:8080", Header: make(http.Header), } req.Header.Set("X-Forwarded-For", "192.168.1.1") - _, err = extractor.Extract(req) + _, err := extractor.Extract(req) var invalidIPErr *InvalidIPError if !errors.As(err, &invalidIPErr) { @@ -1582,7 +1137,7 @@ func TestExtract_ErrorTypes(t *testing.T) { } func TestExtract_IPv4MappedIPv6(t *testing.T) { - extractor := mustNewExtractor(t) + extractor := mustNewExtractor(t, DefaultConfig()) req := &http.Request{ RemoteAddr: "[::ffff:1.1.1.1]:8080", @@ -1606,7 +1161,7 @@ func TestExtract_IPv4MappedIPv6(t *testing.T) { } func TestExtract_Concurrent(t *testing.T) { - extractor, err := New() + extractor, err := New(DefaultConfig()) if err != nil { t.Fatalf("New() error = %v", err) } @@ -1648,7 +1203,7 @@ func TestExtract_Concurrent(t *testing.T) { type contextKey string func TestExtract_ContextPropagation(t *testing.T) { - extractor := mustNewExtractor(t) + extractor := mustNewExtractor(t, DefaultConfig()) ctx := context.WithValue(context.Background(), contextKey("test-key"), "test-value") req := &http.Request{ @@ -1665,7 +1220,7 @@ func TestExtract_ContextPropagation(t *testing.T) { } func TestExtract_NewAPI_Methods(t *testing.T) { - extractor, err := New() + extractor, err := New(DefaultConfig()) if err != nil { t.Fatalf("New() error = %v", err) } @@ -1730,102 +1285,8 @@ func TestExtract_NewAPI_Methods(t *testing.T) { } } -func TestExtract_WithCallOptions_LastWins(t *testing.T) { - extractor, err := New( - WithTrustedProxyAddrs(netip.MustParseAddr("127.0.0.1")), - WithSourcePriority(HeaderSource("X-Frontend-IP"), SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeStrict), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)} - req.Header.Set("X-Frontend-IP", "not-an-ip") - req.Header.Set("X-Forwarded-For", "8.8.8.8") - - tests := []struct { - name string - callOpts []CallOption - wantErr bool - want struct { - IP string - Source Source - } - }{ - { - name: "strict default fails", - callOpts: nil, - wantErr: true, - }, - { - name: "lax call options succeed", - callOpts: []CallOption{ - WithCallSecurityMode(SecurityModeStrict), - WithCallSecurityMode(SecurityModeLax), - }, - want: struct { - IP string - Source Source - }{IP: "8.8.8.8", Source: SourceXForwardedFor}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := extractor.Extract(req, tt.callOpts...) - if tt.wantErr { - if err == nil { - t.Fatal("Extract() error = nil, want non-nil") - } - return - } - - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - - gotView := struct { - IP string - Source Source - }{IP: got.IP.String(), Source: got.Source} - - if diff := cmp.Diff(tt.want, gotView); diff != "" { - t.Fatalf("extraction mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestPrepareCall_NoEffectiveChanges_ReusesExtractorAndSource(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - WithSecurityMode(SecurityModeStrict), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - activeExtractor, activeSource, err := extractor.prepareCall( - WithCallSecurityMode(SecurityModeStrict), - WithCallSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("prepareCall() error = %v", err) - } - - if activeExtractor != extractor { - t.Fatal("prepareCall() should reuse extractor when call options do not change policy") - } - - if activeSource != extractor.source { - t.Fatal("prepareCall() should reuse source chain when call options do not change policy") - } -} - func TestExtract_NilRequest(t *testing.T) { - extractor, err := New() + extractor, err := New(DefaultConfig()) if err != nil { t.Fatalf("New() error = %v", err) } @@ -1847,11 +1308,10 @@ func TestExtract_NilRequest(t *testing.T) { func TestExtract_ErrorSourceReporting(t *testing.T) { t.Run("source-aware error keeps source", func(t *testing.T) { - extractor := mustNewExtractor( - t, - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXRealIP), - ) + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXRealIP} + extractor := mustNewExtractor(t, cfg) req := &http.Request{ RemoteAddr: "1.1.1.1:8080", @@ -1874,7 +1334,7 @@ func TestExtract_ErrorSourceReporting(t *testing.T) { }) t.Run("non source-aware error leaves source empty", func(t *testing.T) { - extractor := mustNewExtractor(t) + extractor := mustNewExtractor(t, DefaultConfig()) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1898,20 +1358,238 @@ func TestExtract_ErrorSourceReporting(t *testing.T) { }) } -func TestNew_OptionsImmutability(t *testing.T) { - trusted := []netip.Prefix{netip.MustParsePrefix("127.0.0.0/8")} - priority := []Source{SourceXForwardedFor, SourceRemoteAddr} +type panicOnNilHeaderProvider struct{} + +func (p *panicOnNilHeaderProvider) Values(string) []string { + if p == nil { + panic("nil header provider should not be called") + } + + return nil +} + +func TestExtractInput_ParityWithExtract(t *testing.T) { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + tests := []struct { + name string + headers []string + remote string + }{ + {name: "remote_addr_only", remote: "8.8.8.8:8080"}, + {name: "xff_success", remote: "1.1.1.1:8080", headers: []string{"8.8.8.8"}}, + {name: "duplicate_xff", remote: "1.1.1.1:8080", headers: []string{"8.8.8.8", "9.9.9.9"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.WithValue(context.Background(), contextKey("trace_id"), "trace-123") + req := (&http.Request{ + RemoteAddr: tt.remote, + Header: make(http.Header), + URL: &url.URL{Path: "/parity"}, + }).WithContext(ctx) + + for _, value := range tt.headers { + req.Header.Add("X-Forwarded-For", value) + } + + httpExtraction, httpErr := extractor.Extract(req) + inputExtraction, inputErr := extractor.ExtractInput(Input{ + Context: req.Context(), + RemoteAddr: req.RemoteAddr, + Path: req.URL.Path, + Headers: req.Header, + }) + + if httpErr != nil { + t.Fatalf("Extract() error = %v", httpErr) + } + if inputErr != nil { + t.Fatalf("ExtractInput() error = %v", inputErr) + } + + if inputExtraction != httpExtraction { + t.Fatalf("extraction mismatch: ExtractInput=%+v Extract=%+v", inputExtraction, httpExtraction) + } + }) + } +} + +func TestExtractInput_HeaderValuesFunc(t *testing.T) { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{HeaderSource("CF-Connecting-IP"), SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + cfHeader := textproto.CanonicalMIMEHeaderKey("CF-Connecting-IP") + requestedHeaders := make([]string, 0, 1) + headers := HeaderValuesFunc(func(name string) []string { + requestedHeaders = append(requestedHeaders, name) + if name == cfHeader { + return []string{"9.9.9.9"} + } + return nil + }) - extractor, err := New( - WithTrustedProxyPrefixes(trusted...), - WithMinTrustedProxies(0), - WithMaxTrustedProxies(0), - WithSourcePriority(priority...), - ) + extraction, err := extractor.ExtractInput(Input{RemoteAddr: "127.0.0.1:8080", Headers: headers}) + if err != nil { + t.Fatalf("ExtractInput() error = %v", err) + } + + if got, want := extraction.Source, HeaderSource("CF-Connecting-IP"); got != want { + t.Fatalf("source = %q, want %q", got, want) + } + if got, want := extraction.IP, netip.MustParseAddr("9.9.9.9"); got != want { + t.Fatalf("ip = %s, want %s", got, want) + } + if len(requestedHeaders) != 1 || requestedHeaders[0] != cfHeader { + t.Fatalf("requested headers = %v, want [%q]", requestedHeaders, cfHeader) + } +} + +func TestExtractInput_RemoteAddrOnlyDoesNotRequestHeaders(t *testing.T) { + extractor, err := New(DefaultConfig()) if err != nil { t.Fatalf("New() error = %v", err) } + requested := 0 + input := Input{ + RemoteAddr: "8.8.8.8:8080", + Headers: HeaderValuesFunc(func(string) []string { + requested++ + return nil + }), + } + + extraction, err := extractor.ExtractInput(input) + if err != nil { + t.Fatalf("ExtractInput() error = %v", err) + } + if got, want := extraction.IP.String(), "8.8.8.8"; got != want { + t.Fatalf("ip = %q, want %q", got, want) + } + if requested != 0 { + t.Fatalf("header provider called %d times, want 0", requested) + } +} + +func TestExtractInput_TypedNilHeaderProviderTreatedAsAbsent(t *testing.T) { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("8.8.8.8")) + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + var nilHTTPHeader *http.Header + var nilProvider *panicOnNilHeaderProvider + + tests := []struct { + name string + headers HeaderValues + }{ + {name: "typed_nil_http_header_pointer", headers: nilHTTPHeader}, + {name: "typed_nil_custom_provider_pointer", headers: nilProvider}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + extraction, extractErr := extractor.ExtractInput(Input{RemoteAddr: "8.8.8.8:8080", Headers: tt.headers}) + if extractErr != nil { + t.Fatalf("ExtractInput() error = %v", extractErr) + } + if got, want := extraction.Source, SourceRemoteAddr; got != want { + t.Fatalf("source = %q, want %q", got, want) + } + if got, want := extraction.IP, netip.MustParseAddr("8.8.8.8"); got != want { + t.Fatalf("ip = %s, want %s", got, want) + } + }) + } +} + +func TestExtractInput_RemoteAddrOnly_RespectsCanceledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + extractor, err := New(DefaultConfig()) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + _, extractErr := extractor.ExtractInput(Input{Context: ctx, RemoteAddr: "8.8.8.8:8080"}) + if !errors.Is(extractErr, context.Canceled) { + t.Fatalf("error = %v, want context.Canceled", extractErr) + } +} + +func TestExtractInput_CanceledContext_DoesNotRequestHeaders(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + requested := 0 + _, extractErr := extractor.ExtractInput(Input{ + Context: ctx, + RemoteAddr: "1.1.1.1:8080", + Headers: HeaderValuesFunc(func(string) []string { + requested++ + return []string{"8.8.8.8"} + }), + }) + if !errors.Is(extractErr, context.Canceled) { + t.Fatalf("error = %v, want context.Canceled", extractErr) + } + if requested != 0 { + t.Fatalf("header provider called %d times, want 0", requested) + } +} + +func TestExtractInput_NilContextDefaultsBackground(t *testing.T) { + input := Input{RemoteAddr: "8.8.8.8:8080"} + extractor, err := New(DefaultConfig()) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + extraction, err := extractor.ExtractInput(input) + if err != nil { + t.Fatalf("ExtractInput() error = %v", err) + } + if got, want := extraction.IP.String(), "8.8.8.8"; got != want { + t.Fatalf("IP = %q, want %q", got, want) + } + if got, want := extraction.Source, SourceRemoteAddr; got != want { + t.Fatalf("Source = %q, want %q", got, want) + } + + addr, err := extractor.ExtractInputAddr(input) + if err != nil { + t.Fatalf("ExtractInputAddr() error = %v", err) + } + if got, want := addr.String(), "8.8.8.8"; got != want { + t.Fatalf("IP = %q, want %q", got, want) + } +} + +func TestNew_ConfigInputImmutability(t *testing.T) { + trusted := []netip.Prefix{netip.MustParsePrefix("127.0.0.0/8")} + priority := []Source{SourceXForwardedFor, SourceRemoteAddr} + + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = clonePrefixes(trusted) + cfg.MinTrustedProxies = 0 + cfg.MaxTrustedProxies = 0 + cfg.Sources = cloneSources(priority) + extractor := mustNewExtractor(t, cfg) + trusted[0] = netip.MustParsePrefix("10.0.0.0/8") priority[0] = SourceRemoteAddr diff --git a/forwarded_test.go b/forwarded_test.go deleted file mode 100644 index fadbfab..0000000 --- a/forwarded_test.go +++ /dev/null @@ -1,206 +0,0 @@ -package clientip - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestParseForwardedValues(t *testing.T) { - tests := []struct { - name string - values []string - want []string - wantErr error - }{ - { - name: "single for value", - values: []string{"for=1.1.1.1"}, - want: []string{"1.1.1.1"}, - }, - { - name: "case-insensitive parameter name", - values: []string{"For=1.1.1.1"}, - want: []string{"1.1.1.1"}, - }, - { - name: "multiple elements in one header", - values: []string{"for=1.1.1.1, for=8.8.8.8"}, - want: []string{"1.1.1.1", "8.8.8.8"}, - }, - { - name: "multiple header lines", - values: []string{"for=1.1.1.1", "for=8.8.8.8"}, - want: []string{"1.1.1.1", "8.8.8.8"}, - }, - { - name: "parameters with semicolons", - values: []string{"for=1.1.1.1;proto=https;by=10.0.0.1"}, - want: []string{"1.1.1.1"}, - }, - { - name: "quoted IPv6 and port", - values: []string{"for=\"[2606:4700:4700::1]:8080\""}, - want: []string{"[2606:4700:4700::1]:8080"}, - }, - { - name: "quoted comma is not treated as element delimiter", - values: []string{"for=\"1.1.1.1,8.8.8.8\";proto=https"}, - want: []string{"1.1.1.1,8.8.8.8"}, - }, - { - name: "quoted semicolon is not treated as param delimiter", - values: []string{"for=\"1.1.1.1;edge\";proto=https"}, - want: []string{"1.1.1.1;edge"}, - }, - { - name: "escaped quote remains inside quoted value", - values: []string{`for="1.1.1.1\";edge";proto=https`}, - want: []string{`1.1.1.1";edge`}, - }, - { - name: "ignores element without for parameter", - values: []string{"proto=https;by=10.0.0.1, for=8.8.8.8"}, - want: []string{"8.8.8.8"}, - }, - { - name: "invalid parameter format", - values: []string{"for"}, - wantErr: ErrInvalidForwardedHeader, - }, - { - name: "unterminated quoted string", - values: []string{"for=\"1.1.1.1"}, - wantErr: ErrInvalidForwardedHeader, - }, - { - name: "duplicate for parameter", - values: []string{"for=1.1.1.1;for=8.8.8.8"}, - wantErr: ErrInvalidForwardedHeader, - }, - { - name: "trailing escape in quoted value", - values: []string{`for="1.1.1.1\`}, - wantErr: ErrInvalidForwardedHeader, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := mustNewExtractor(t) - got, err := extractor.parseForwardedValues(tt.values) - - if tt.wantErr != nil { - if !errorContains(err, tt.wantErr) { - t.Fatalf("parseForwardedValues() error = %v, want %v", err, tt.wantErr) - } - return - } - - if err != nil { - t.Fatalf("parseForwardedValues() error = %v, want nil", err) - } - - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatalf("parseForwardedValues() mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestParseForwardedValues_MaxChainLength(t *testing.T) { - extractor := mustNewExtractor(t, WithMaxChainLength(2)) - - _, err := extractor.parseForwardedValues([]string{"for=1.1.1.1, for=2.2.2.2, for=3.3.3.3"}) - if !errorContains(err, ErrChainTooLong) { - t.Fatalf("parseForwardedValues() error = %v, want ErrChainTooLong", err) - } -} - -func TestParseForwardedValues_MalformedParameterMatrix(t *testing.T) { - extractor := mustNewExtractor(t) - - tests := []struct { - name string - values []string - }{ - {name: "empty parameter key", values: []string{"=1.1.1.1"}}, - {name: "empty for value", values: []string{"for="}}, - {name: "empty quoted for value", values: []string{`for=""`}}, - {name: "invalid quoted for value suffix", values: []string{`for="1.1.1.1"extra`}}, - {name: "non for parameter missing equals", values: []string{"for=1.1.1.1;proto"}}, - {name: "non for parameter empty key", values: []string{"for=1.1.1.1;=https"}}, - {name: "non for parameter empty value", values: []string{"for=1.1.1.1;proto="}}, - {name: "unterminated quoted value across params", values: []string{"for=1.1.1.1;proto=\"https"}}, - {name: "unbalanced quotes in element", values: []string{`for="a"b"`}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, parseErr := extractor.parseForwardedValues(tt.values) - - got := struct { - HasErr bool - IsInvalid bool - }{ - HasErr: parseErr != nil, - IsInvalid: errorContains(parseErr, ErrInvalidForwardedHeader), - } - - want := struct { - HasErr bool - IsInvalid bool - }{ - HasErr: true, - IsInvalid: true, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Fatalf("parseForwardedValues() mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestParseForwardedForValue(t *testing.T) { - tests := []struct { - name string - input string - want string - wantErr bool - }{ - {name: "unquoted token", input: "1.1.1.1", want: "1.1.1.1"}, - {name: "quoted token", input: `"1.1.1.1"`, want: "1.1.1.1"}, - {name: "quoted token with surrounding spaces", input: ` "1.1.1.1" `, want: "1.1.1.1"}, - {name: "escaped quote in quoted token", input: `"1.1.1.1\\\"edge"`, want: `1.1.1.1\"edge`}, - {name: "empty input", input: "", wantErr: true}, - {name: "spaces only", input: " ", wantErr: true}, - {name: "unterminated quote", input: `"1.1.1.1`, wantErr: true}, - {name: "unexpected inner quote", input: `"a"b"`, wantErr: true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseForwardedForValue(tt.input) - - gotView := struct { - Value string - HasErr bool - }{ - Value: got, - HasErr: err != nil, - } - wantView := struct { - Value string - HasErr bool - }{ - Value: tt.want, - HasErr: tt.wantErr, - } - - if diff := cmp.Diff(wantView, gotView); diff != "" { - t.Fatalf("parseForwardedForValue() mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/input.go b/input.go new file mode 100644 index 0000000..12125bb --- /dev/null +++ b/input.go @@ -0,0 +1,51 @@ +package clientip + +import "context" + +// HeaderValues provides access to request header values by name. +// +// Implementations should return one slice entry per received header line. +// Single-IP sources rely on per-line values to detect duplicates, and chain +// sources preserve wire order across repeated lines. +// +// Header names are requested in canonical MIME format (for example +// "X-Forwarded-For"). +// +// net/http's http.Header satisfies this interface directly. +type HeaderValues interface { + Values(name string) []string +} + +// HeaderValuesFunc adapts a function to the HeaderValues interface. +type HeaderValuesFunc func(name string) []string + +// Values implements HeaderValues. +func (f HeaderValuesFunc) Values(name string) []string { + if f == nil { + return nil + } + + return f(name) +} + +// Input provides framework-agnostic request data for extraction. +// +// Context defaults to context.Background() when nil. +// +// For Headers, preserve repeated header lines as separate values for each +// header name (for example two X-Forwarded-For lines should yield a slice with +// length 2, and two X-Real-IP lines should also yield length 2). +type Input struct { + Context context.Context + RemoteAddr string + Path string + Headers HeaderValues +} + +func requestInputContext(input Input) context.Context { + if input.Context == nil { + return context.Background() + } + + return input.Context +} diff --git a/request_input_test.go b/input_test.go similarity index 58% rename from request_input_test.go rename to input_test.go index a22779f..0f330e8 100644 --- a/request_input_test.go +++ b/input_test.go @@ -28,34 +28,34 @@ func (p *panicIfCalledHeaderProvider) Values(string) []string { return nil } -func TestRequestFromInput_HeaderProviderPaths(t *testing.T) { +func TestRequestViewFromInput_HeaderProviderPaths(t *testing.T) { var nilHTTPHeader *http.Header var nilHeaderFunc HeaderValuesFunc var nilCustomProvider *panicIfCalledHeaderProvider tests := []struct { name string - sourceHeaderKeys []string - newInput func(t *testing.T) (RequestInput, func() []string) + headerName string + newInput func(t *testing.T) (Input, func() []string) wantRemoteAddr string wantPath string - wantHeaders http.Header + wantHeaderValues []string wantProviderCalls []string }{ { - name: "nil headers treated as absent", - sourceHeaderKeys: []string{"X-Forwarded-For"}, - newInput: func(*testing.T) (RequestInput, func() []string) { - return RequestInput{RemoteAddr: "8.8.8.8:443", Path: "/nil-headers"}, nil + name: "nil headers treated as absent", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { + return Input{RemoteAddr: "8.8.8.8:443", Path: "/nil-headers"}, nil }, wantRemoteAddr: "8.8.8.8:443", wantPath: "/nil-headers", }, { - name: "http.Header passthrough", - sourceHeaderKeys: []string{"X-Forwarded-For"}, - newInput: func(*testing.T) (RequestInput, func() []string) { - return RequestInput{ + name: "http.Header passthrough", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { + return Input{ RemoteAddr: "1.1.1.1:80", Path: "/http-header", Headers: http.Header{ @@ -64,29 +64,26 @@ func TestRequestFromInput_HeaderProviderPaths(t *testing.T) { }, }, nil }, - wantRemoteAddr: "1.1.1.1:80", - wantPath: "/http-header", - wantHeaders: http.Header{ - "X-Forwarded-For": {"8.8.8.8", "9.9.9.9"}, - "X-Real-IP": {"4.4.4.4"}, - }, + wantRemoteAddr: "1.1.1.1:80", + wantPath: "/http-header", + wantHeaderValues: []string{"8.8.8.8", "9.9.9.9"}, }, { - name: "*http.Header passthrough", - sourceHeaderKeys: []string{"Forwarded"}, - newInput: func(*testing.T) (RequestInput, func() []string) { + name: "*http.Header passthrough", + headerName: "Forwarded", + newInput: func(*testing.T) (Input, func() []string) { h := http.Header{"Forwarded": {"for=1.1.1.1"}} - return RequestInput{RemoteAddr: "2.2.2.2:80", Path: "/header-pointer", Headers: &h}, nil + return Input{RemoteAddr: "2.2.2.2:80", Path: "/header-pointer", Headers: &h}, nil }, - wantRemoteAddr: "2.2.2.2:80", - wantPath: "/header-pointer", - wantHeaders: http.Header{"Forwarded": {"for=1.1.1.1"}}, + wantRemoteAddr: "2.2.2.2:80", + wantPath: "/header-pointer", + wantHeaderValues: []string{"for=1.1.1.1"}, }, { - name: "typed nil *http.Header treated as absent", - sourceHeaderKeys: []string{"X-Forwarded-For"}, - newInput: func(*testing.T) (RequestInput, func() []string) { - return RequestInput{ + name: "typed nil *http.Header treated as absent", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { + return Input{ RemoteAddr: "3.3.3.3:80", Path: "/typed-nil-header", Headers: nilHTTPHeader, @@ -96,10 +93,10 @@ func TestRequestFromInput_HeaderProviderPaths(t *testing.T) { wantPath: "/typed-nil-header", }, { - name: "nil HeaderValuesFunc treated as absent", - sourceHeaderKeys: []string{"X-Forwarded-For"}, - newInput: func(*testing.T) (RequestInput, func() []string) { - return RequestInput{ + name: "nil HeaderValuesFunc treated as absent", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { + return Input{ RemoteAddr: "4.4.4.4:80", Path: "/nil-header-func", Headers: nilHeaderFunc, @@ -109,50 +106,47 @@ func TestRequestFromInput_HeaderProviderPaths(t *testing.T) { wantPath: "/nil-header-func", }, { - name: "single-key provider path", - sourceHeaderKeys: []string{"X-Forwarded-For"}, - newInput: func(*testing.T) (RequestInput, func() []string) { + name: "provider path forwards header lookups", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { provider := &trackingHeaderProvider{ values: map[string][]string{ "X-Forwarded-For": {"8.8.8.8"}, "X-Real-IP": {"7.7.7.7"}, }, } - return RequestInput{RemoteAddr: "5.5.5.5:80", Path: "/single-key", Headers: provider}, func() []string { + return Input{RemoteAddr: "5.5.5.5:80", Path: "/single-key", Headers: provider}, func() []string { return provider.calls } }, wantRemoteAddr: "5.5.5.5:80", wantPath: "/single-key", - wantHeaders: http.Header{"X-Forwarded-For": {"8.8.8.8"}}, + wantHeaderValues: []string{"8.8.8.8"}, wantProviderCalls: []string{"X-Forwarded-For"}, }, { - name: "multiple-key provider path", - sourceHeaderKeys: []string{"Forwarded", "X-Forwarded-For", "X-Real-IP"}, - newInput: func(*testing.T) (RequestInput, func() []string) { + name: "provider path does not prefetch unrelated headers", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { provider := &trackingHeaderProvider{ values: map[string][]string{ "X-Forwarded-For": {"8.8.8.8", "9.9.9.9"}, "X-Real-IP": {"7.7.7.7"}, }, } - return RequestInput{RemoteAddr: "6.6.6.6:80", Path: "/multiple-keys", Headers: provider}, func() []string { + return Input{RemoteAddr: "6.6.6.6:80", Path: "/multiple-keys", Headers: provider}, func() []string { return provider.calls } }, - wantRemoteAddr: "6.6.6.6:80", - wantPath: "/multiple-keys", - wantHeaders: http.Header{ - "X-Forwarded-For": {"8.8.8.8", "9.9.9.9"}, - "X-Real-IP": {"7.7.7.7"}, - }, - wantProviderCalls: []string{"Forwarded", "X-Forwarded-For", "X-Real-IP"}, + wantRemoteAddr: "6.6.6.6:80", + wantPath: "/multiple-keys", + wantHeaderValues: []string{"8.8.8.8", "9.9.9.9"}, + wantProviderCalls: []string{"X-Forwarded-For"}, }, { - name: "HeaderValuesFunc provider path", - sourceHeaderKeys: []string{"X-Forwarded-For"}, - newInput: func(*testing.T) (RequestInput, func() []string) { + name: "HeaderValuesFunc provider path", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { calls := make([]string, 0, 1) headers := HeaderValuesFunc(func(name string) []string { calls = append(calls, name) @@ -161,38 +155,37 @@ func TestRequestFromInput_HeaderProviderPaths(t *testing.T) { } return nil }) - return RequestInput{RemoteAddr: "6.6.6.7:80", Path: "/header-func", Headers: headers}, func() []string { + return Input{RemoteAddr: "6.6.6.7:80", Path: "/header-func", Headers: headers}, func() []string { return calls } }, wantRemoteAddr: "6.6.6.7:80", wantPath: "/header-func", - wantHeaders: http.Header{"X-Forwarded-For": {"8.8.8.8"}}, + wantHeaderValues: []string{"8.8.8.8"}, wantProviderCalls: []string{"X-Forwarded-For"}, }, { - name: "no source keys skips provider", - sourceHeaderKeys: nil, - newInput: func(t *testing.T) (RequestInput, func() []string) { + name: "view only calls provider when Values is used", + headerName: "X-Forwarded-For", + newInput: func(t *testing.T) (Input, func() []string) { calls := make([]string, 0, 1) headers := HeaderValuesFunc(func(name string) []string { calls = append(calls, name) - t.Fatalf("header provider should not be called when sourceHeaderKeys is empty (called with %q)", name) return nil }) - return RequestInput{RemoteAddr: "7.7.7.7:80", Path: "/skip-provider", Headers: headers}, func() []string { + return Input{RemoteAddr: "7.7.7.7:80", Path: "/skip-provider", Headers: headers}, func() []string { return calls } }, wantRemoteAddr: "7.7.7.7:80", wantPath: "/skip-provider", - wantProviderCalls: []string{}, + wantProviderCalls: []string{"X-Forwarded-For"}, }, { - name: "typed nil custom provider treated as absent", - sourceHeaderKeys: []string{"X-Forwarded-For"}, - newInput: func(*testing.T) (RequestInput, func() []string) { - return RequestInput{ + name: "typed nil custom provider treated as absent", + headerName: "X-Forwarded-For", + newInput: func(*testing.T) (Input, func() []string) { + return Input{ RemoteAddr: "9.9.9.9:80", Path: "/typed-nil-provider", Headers: nilCustomProvider, @@ -206,28 +199,26 @@ func TestRequestFromInput_HeaderProviderPaths(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { input, providerCalls := tt.newInput(t) - req := requestFromInput(input, tt.sourceHeaderKeys) + view := requestViewFromInput(input) got := struct { RemoteAddr string Path string - Headers http.Header + Values []string }{ - RemoteAddr: req.RemoteAddr, - Headers: req.Header, - } - if req.URL != nil { - got.Path = req.URL.Path + RemoteAddr: view.remoteAddr(), + Path: view.path(), + Values: view.values(tt.headerName), } want := struct { RemoteAddr string Path string - Headers http.Header + Values []string }{ RemoteAddr: tt.wantRemoteAddr, Path: tt.wantPath, - Headers: tt.wantHeaders, + Values: tt.wantHeaderValues, } if diff := cmp.Diff(want, got, cmpopts.EquateEmpty()); diff != "" { diff --git a/ip_parse_test.go b/ip_parse_test.go deleted file mode 100644 index 47acf77..0000000 --- a/ip_parse_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package clientip - -import ( - "net/netip" - "testing" -) - -func TestParseIP(t *testing.T) { - tests := []struct { - name string - input string - want netip.Addr - wantErr bool - }{ - { - name: "valid IPv4", - input: "203.0.113.1", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with leading whitespace", - input: " 203.0.113.1", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with trailing whitespace", - input: "203.0.113.1 ", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with surrounding whitespace", - input: " 203.0.113.1 ", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with tabs", - input: "\t203.0.113.1\t", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with port", - input: "203.0.113.1:8080", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with double quotes", - input: `"203.0.113.1"`, - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with single quotes", - input: "'203.0.113.1'", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv4 with quotes and port", - input: `"203.0.113.1:8080"`, - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "valid IPv6", - input: "2001:db8::1", - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "valid IPv6 with brackets", - input: "[2001:db8::1]", - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "valid IPv6 with brackets and port", - input: "[2001:db8::1]:8080", - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "valid IPv6 with whitespace and brackets", - input: " [2001:db8::1] ", - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "localhost IPv4", - input: "127.0.0.1", - want: netip.MustParseAddr("127.0.0.1"), - }, - { - name: "localhost IPv4 with port", - input: "127.0.0.1:8080", - want: netip.MustParseAddr("127.0.0.1"), - }, - { - name: "localhost IPv6", - input: "::1", - want: netip.MustParseAddr("::1"), - }, - { - name: "localhost IPv6 with brackets and port", - input: "[::1]:8080", - want: netip.MustParseAddr("::1"), - }, - { - name: "empty string", - input: "", - wantErr: true, - }, - { - name: "whitespace only", - input: " ", - wantErr: true, - }, - { - name: "quotes only", - input: `""`, - wantErr: true, - }, - { - name: "unmatched leading double quote", - input: `"203.0.113.1`, - wantErr: true, - }, - { - name: "unmatched trailing double quote", - input: `203.0.113.1"`, - wantErr: true, - }, - { - name: "unmatched leading single quote", - input: `'203.0.113.1`, - wantErr: true, - }, - { - name: "unmatched trailing single quote", - input: `203.0.113.1'`, - wantErr: true, - }, - { - name: "invalid IP", - input: "not-an-ip", - wantErr: true, - }, - { - name: "invalid IPv4", - input: "999.999.999.999", - wantErr: true, - }, - { - name: "port only", - input: ":8080", - wantErr: true, - }, - { - name: "brackets only", - input: "[]", - wantErr: true, - }, - { - name: "unmatched leading bracket", - input: "[2001:db8::1", - wantErr: true, - }, - { - name: "unmatched trailing bracket", - input: "2001:db8::1]", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := parseIP(tt.input) - if tt.wantErr { - if got.IsValid() { - t.Errorf("parseIP(%q) = %v, want invalid", tt.input, got) - } - } else { - if !got.IsValid() { - t.Errorf("parseIP(%q) = invalid, want %v", tt.input, tt.want) - return - } - if got != tt.want { - t.Errorf("parseIP(%q) = %v, want %v", tt.input, got, tt.want) - } - } - }) - } -} - -func TestParseRemoteAddr(t *testing.T) { - tests := []struct { - name string - input string - want netip.Addr - wantErr bool - }{ - { - name: "ipv4 host:port", - input: "203.0.113.1:8080", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "ipv6 host:port", - input: "[2001:db8::1]:443", - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "bare ipv4 fallback", - input: "203.0.113.1", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "bare ipv6 fallback", - input: "2001:db8::1", - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "bracketed ipv6 fallback", - input: "[2001:db8::1]", - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "quoted ipv4 fallback", - input: `"203.0.113.1"`, - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "hostname with port", - input: "example.com:443", - wantErr: true, - }, - { - name: "non-numeric port is ignored", - input: "203.0.113.1:notaport", - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "empty", - input: "", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := parseRemoteAddr(tt.input) - if tt.wantErr { - if got.IsValid() { - t.Errorf("parseRemoteAddr(%q) = %v, want invalid", tt.input, got) - } - return - } - - if !got.IsValid() { - t.Errorf("parseRemoteAddr(%q) = invalid, want %v", tt.input, tt.want) - return - } - - if got != tt.want { - t.Errorf("parseRemoteAddr(%q) = %v, want %v", tt.input, got, tt.want) - } - }) - } -} - -func TestTrimMatchedChar(t *testing.T) { - tests := []struct { - name string - input string - ch byte - want string - }{ - { - name: "matching double quote delimiter", - input: `"203.0.113.1"`, - ch: '"', - want: "203.0.113.1", - }, - { - name: "matching single quote delimiter", - input: "'203.0.113.1'", - ch: '\'', - want: "203.0.113.1", - }, - { - name: "non-matching delimiter", - input: "203.0.113.1", - ch: '"', - want: "203.0.113.1", - }, - { - name: "too short to trim", - input: `"`, - ch: '"', - want: `"`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := trimMatchedChar(tt.input, tt.ch) - if got != tt.want { - t.Errorf("trimMatchedChar(%q, %q) = %q, want %q", tt.input, tt.ch, got, tt.want) - } - }) - } -} - -func TestTrimMatchedPair(t *testing.T) { - tests := []struct { - name string - input string - start byte - end byte - want string - }{ - { - name: "matching pair", - input: "[2001:db8::1]", - start: '[', - end: ']', - want: "2001:db8::1", - }, - { - name: "unmatched leading bracket", - input: "[2001:db8::1", - start: '[', - end: ']', - want: "[2001:db8::1", - }, - { - name: "unmatched trailing bracket", - input: "2001:db8::1]", - start: '[', - end: ']', - want: "2001:db8::1]", - }, - { - name: "too short to trim", - input: "[", - start: '[', - end: ']', - want: "[", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := trimMatchedPair(tt.input, tt.start, tt.end) - if got != tt.want { - t.Errorf("trimMatchedPair(%q, %q, %q) = %q, want %q", tt.input, tt.start, tt.end, got, tt.want) - } - }) - } -} - -func TestNormalizeIP(t *testing.T) { - tests := []struct { - name string - input netip.Addr - want netip.Addr - }{ - { - name: "IPv4 - no change", - input: netip.MustParseAddr("203.0.113.1"), - want: netip.MustParseAddr("203.0.113.1"), - }, - { - name: "IPv6 - no change", - input: netip.MustParseAddr("2001:db8::1"), - want: netip.MustParseAddr("2001:db8::1"), - }, - { - name: "IPv4-mapped IPv6 - unmapped", - input: netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 203, 0, 113, 1}), - want: netip.MustParseAddr("203.0.113.1"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := normalizeIP(tt.input) - if got != tt.want { - t.Errorf("normalizeIP(%v) = %v, want %v", tt.input, got, tt.want) - } - }) - } -} diff --git a/ip_validation.go b/ip_validation.go deleted file mode 100644 index 82c2927..0000000 --- a/ip_validation.go +++ /dev/null @@ -1,88 +0,0 @@ -package clientip - -import "net/netip" - -var ( - reservedClientIPv4Prefixes = []netip.Prefix{ - mustParsePrefix("0.0.0.0/8"), - mustParsePrefix("100.64.0.0/10"), - mustParsePrefix("192.0.0.0/24"), - mustParsePrefix("192.0.2.0/24"), - mustParsePrefix("198.18.0.0/15"), - mustParsePrefix("198.51.100.0/24"), - mustParsePrefix("203.0.113.0/24"), - mustParsePrefix("240.0.0.0/4"), - } - - reservedClientIPv6Prefixes = []netip.Prefix{ - mustParsePrefix("64:ff9b::/96"), - mustParsePrefix("64:ff9b:1::/48"), - mustParsePrefix("100::/64"), - mustParsePrefix("2001:2::/48"), - mustParsePrefix("2001:db8::/32"), - mustParsePrefix("2001:20::/28"), - } -) - -func (e *Extractor) isPlausibleClientIP(ip netip.Addr) bool { - if !ip.IsValid() { - return false - } - - if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsMulticast() || ip.IsUnspecified() { - e.config.metrics.RecordSecurityEvent(securityEventInvalidIP) - return false - } - - // Check for reserved/special-use ranges that should never be client IPs - if isReservedIP(ip) && !e.isAllowlistedReservedClientIP(ip) { - e.config.metrics.RecordSecurityEvent(securityEventReservedIP) - return false - } - - if !e.config.allowPrivateIPs && ip.IsPrivate() { - e.config.metrics.RecordSecurityEvent(securityEventPrivateIP) - return false - } - - return true -} - -func (e *Extractor) isAllowlistedReservedClientIP(ip netip.Addr) bool { - if len(e.config.allowReservedClientPrefixes) == 0 || !ip.IsValid() { - return false - } - - ip = normalizeIP(ip) - - for _, prefix := range e.config.allowReservedClientPrefixes { - if prefix.Contains(ip) { - return true - } - } - - return false -} - -// isReservedIP checks if an IP is in a reserved or special-use range that -// should never appear as a real client IP address. -func isReservedIP(ip netip.Addr) bool { - if !ip.IsValid() { - return false - } - - ip = normalizeIP(ip) - - prefixes := reservedClientIPv6Prefixes - if ip.Is4() { - prefixes = reservedClientIPv4Prefixes - } - - for _, prefix := range prefixes { - if prefix.Contains(ip) { - return true - } - } - - return false -} diff --git a/logger.go b/logger.go deleted file mode 100644 index de4f773..0000000 --- a/logger.go +++ /dev/null @@ -1,25 +0,0 @@ -package clientip - -import ( - "context" -) - -// Logger records security-significant events emitted by Extractor. -// -// Implementations should be safe for concurrent use, as a single Extractor -// instance is typically shared across many goroutines. -// -// The provided context comes from the inbound HTTP request and can carry -// tracing metadata (for example, trace or span IDs). -// -// The interface intentionally mirrors slog's WarnContext signature, so -// *slog.Logger can be used directly without an adapter. -type Logger interface { - WarnContext(ctx context.Context, msg string, args ...any) -} - -// noopLogger is the default Logger implementation when logging is not -// explicitly configured. -type noopLogger struct{} - -func (noopLogger) WarnContext(context.Context, string, ...any) {} diff --git a/logger_test.go b/logger_test.go deleted file mode 100644 index 378bf2e..0000000 --- a/logger_test.go +++ /dev/null @@ -1,410 +0,0 @@ -package clientip - -import ( - "context" - "errors" - "net/netip" - "sync" - "testing" -) - -type loggerTestContextKey string - -type capturedLogEntry struct { - ctx context.Context - attrs map[string]any -} - -type capturedLogger struct { - mu sync.Mutex - entries []capturedLogEntry -} - -func (l *capturedLogger) WarnContext(ctx context.Context, msg string, args ...any) { - l.mu.Lock() - defer l.mu.Unlock() - - l.entries = append(l.entries, capturedLogEntry{ - ctx: ctx, - attrs: attrsToMap(args), - }) -} - -func (l *capturedLogger) snapshot() []capturedLogEntry { - l.mu.Lock() - defer l.mu.Unlock() - - entries := make([]capturedLogEntry, len(l.entries)) - copy(entries, l.entries) - return entries -} - -func attrsToMap(args []any) map[string]any { - attrs := make(map[string]any) - for i := 0; i+1 < len(args); i += 2 { - key, ok := args[i].(string) - if !ok { - continue - } - attrs[key] = args[i+1] - } - return attrs -} - -func assertAttr(t *testing.T, attrs map[string]any, key string, want any) { - t.Helper() - - got, ok := attrs[key] - if !ok { - t.Fatalf("missing %q attr", key) - } - - if got != want { - t.Fatalf("%s attr = %v, want %v", key, got, want) - } -} - -func assertCommonSecurityWarningAttrs(t *testing.T, attrs map[string]any, event string, source Source, path, remoteAddr string) { - t.Helper() - - assertAttr(t, attrs, "event", event) - assertAttr(t, attrs, "source", source.String()) - assertAttr(t, attrs, "path", path) - assertAttr(t, attrs, "remote_addr", remoteAddr) -} - -func TestLogging_MultipleXFFHeaders_DoNotWarn(t *testing.T) { - logger := &capturedLogger{} - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:8080", "/test/multiple-headers") - req.Header.Add("X-Forwarded-For", "8.8.8.8") - req.Header.Add("X-Forwarded-For", "9.9.9.9") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("expected extraction success, got error: %v", err) - } - if got, want := result.Source, SourceXForwardedFor; got != want { - t.Fatalf("source = %q, want %q", got, want) - } - if got, want := result.IP.String(), "9.9.9.9"; got != want { - t.Fatalf("ip = %q, want %q", got, want) - } - - entries := logger.snapshot() - if len(entries) != 0 { - t.Fatalf("logged entries = %d, want 0", len(entries)) - } -} - -func TestLogging_MultipleSingleIPHeaders_EmitsWarning(t *testing.T) { - logger := &capturedLogger{} - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXRealIP), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:8080", "/test/multiple-single-ip-headers") - req.Header.Add("X-Real-IP", "8.8.8.8") - req.Header.Add("X-Real-IP", "9.9.9.9") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail for multiple single-IP headers") - } - if !errors.Is(err, ErrMultipleSingleIPHeaders) { - t.Fatalf("error = %v, want ErrMultipleSingleIPHeaders", err) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventMultipleHeaders, - SourceXRealIP, - "/test/multiple-single-ip-headers", - "1.1.1.1:8080", - ) - assertAttr(t, entry.attrs, "header", "X-Real-IP") - assertAttr(t, entry.attrs, "header_count", 2) -} - -func TestLogging_ChainTooLong_EmitsWarning(t *testing.T) { - logger := &capturedLogger{} - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor), - WithMaxChainLength(2), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:8080", "/test/chain-too-long") - req.Header.Set("X-Forwarded-For", "8.8.8.8, 9.9.9.9, 4.4.4.4") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail for overlong X-Forwarded-For chain") - } - if !errors.Is(err, ErrChainTooLong) { - t.Fatalf("error = %v, want ErrChainTooLong", err) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventChainTooLong, - SourceXForwardedFor, - "/test/chain-too-long", - "1.1.1.1:8080", - ) - assertAttr(t, entry.attrs, "chain_length", 3) - assertAttr(t, entry.attrs, "max_length", 2) -} - -func TestLogging_TooFewTrustedProxies_EmitsWarning(t *testing.T) { - logger := &capturedLogger{} - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(2), - WithMaxTrustedProxies(3), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("10.0.0.1:8080", "/test/proxy-count") - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail for too few trusted proxies") - } - if !errors.Is(err, ErrTooFewTrustedProxies) { - t.Fatalf("error = %v, want ErrTooFewTrustedProxies", err) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventTooFewTrustedProxies, - SourceXForwardedFor, - "/test/proxy-count", - "10.0.0.1:8080", - ) - assertAttr(t, entry.attrs, "trusted_proxy_count", 1) - assertAttr(t, entry.attrs, "min_trusted_proxies", 2) - assertAttr(t, entry.attrs, "max_trusted_proxies", 3) -} - -func TestLogging_NoTrustedProxies_EmitsWarning(t *testing.T) { - logger := &capturedLogger{} - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("10.0.0.1:8080", "/test/no-trusted-proxies") - req.Header.Set("X-Forwarded-For", "1.1.1.1") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail when no trusted proxies are present in XFF") - } - if !errors.Is(err, ErrNoTrustedProxies) { - t.Fatalf("error = %v, want ErrNoTrustedProxies", err) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventNoTrustedProxies, - SourceXForwardedFor, - "/test/no-trusted-proxies", - "10.0.0.1:8080", - ) - assertAttr(t, entry.attrs, "trusted_proxy_count", 0) - assertAttr(t, entry.attrs, "min_trusted_proxies", 1) - assertAttr(t, entry.attrs, "max_trusted_proxies", 3) -} - -func TestLogging_TooManyTrustedProxies_EmitsWarning(t *testing.T) { - logger := &capturedLogger{} - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(1), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("10.0.0.2:8080", "/test/too-many-proxies") - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1, 10.0.0.2") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail for too many trusted proxies") - } - if !errors.Is(err, ErrTooManyTrustedProxies) { - t.Fatalf("error = %v, want ErrTooManyTrustedProxies", err) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventTooManyTrustedProxies, - SourceXForwardedFor, - "/test/too-many-proxies", - "10.0.0.2:8080", - ) - assertAttr(t, entry.attrs, "trusted_proxy_count", 2) - assertAttr(t, entry.attrs, "min_trusted_proxies", 1) - assertAttr(t, entry.attrs, "max_trusted_proxies", 1) -} - -func TestLogging_UntrustedProxy_EmitsWarning(t *testing.T) { - logger := &capturedLogger{} - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("8.8.8.8:8080", "/test/untrusted-proxy") - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail for untrusted proxy") - } - if !errors.Is(err, ErrUntrustedProxy) { - t.Fatalf("error = %v, want ErrUntrustedProxy", err) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventUntrustedProxy, - SourceXForwardedFor, - "/test/untrusted-proxy", - "8.8.8.8:8080", - ) -} - -func TestLogging_MalformedForwarded_EmitsWarning(t *testing.T) { - logger := &capturedLogger{} - - extractor, err := New( - WithLogger(logger), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceForwarded, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:8080", "/test/malformed-forwarded") - req.Header.Set("Forwarded", "for=\"1.1.1.1") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail closed on malformed Forwarded") - } - if !errors.Is(err, ErrInvalidForwardedHeader) { - t.Fatalf("error = %v, want ErrInvalidForwardedHeader", err) - } - if result.Source != SourceForwarded { - t.Fatalf("source = %q, want %q", result.Source, SourceForwarded) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) - } - - entry := entries[0] - assertCommonSecurityWarningAttrs( - t, - entry.attrs, - securityEventMalformedForwarded, - SourceForwarded, - "/test/malformed-forwarded", - "1.1.1.1:8080", - ) -} diff --git a/metrics.go b/metrics.go deleted file mode 100644 index 6ab9816..0000000 --- a/metrics.go +++ /dev/null @@ -1,28 +0,0 @@ -package clientip - -// Metrics records extraction outcomes and security events emitted by -// Extractor. -// -// Implementations should be safe for concurrent use, as a single Extractor -// instance is typically shared across many goroutines. -type Metrics interface { - // RecordExtractionSuccess is called when a source successfully returns a - // client IP. - RecordExtractionSuccess(source string) - // RecordExtractionFailure is called when a source is attempted but cannot - // return a valid client IP. - RecordExtractionFailure(source string) - // RecordSecurityEvent is called when the extractor observes a - // security-relevant condition. - RecordSecurityEvent(event string) -} - -// noopMetrics is the default Metrics implementation when metrics are not -// explicitly configured. -type noopMetrics struct{} - -func (noopMetrics) RecordExtractionSuccess(string) {} - -func (noopMetrics) RecordExtractionFailure(string) {} - -func (noopMetrics) RecordSecurityEvent(string) {} diff --git a/metrics_test.go b/metrics_test.go deleted file mode 100644 index d3c4d18..0000000 --- a/metrics_test.go +++ /dev/null @@ -1,503 +0,0 @@ -package clientip - -import ( - "net/netip" - "sync" - "testing" -) - -type mockMetrics struct { - mu sync.Mutex - successCount map[string]int - failureCount map[string]int - securityEvents map[string]int -} - -func newMockMetrics() *mockMetrics { - return &mockMetrics{ - successCount: make(map[string]int), - failureCount: make(map[string]int), - securityEvents: make(map[string]int), - } -} - -func (m *mockMetrics) RecordExtractionSuccess(source string) { - m.mu.Lock() - defer m.mu.Unlock() - m.successCount[source]++ -} - -func (m *mockMetrics) RecordExtractionFailure(source string) { - m.mu.Lock() - defer m.mu.Unlock() - m.failureCount[source]++ -} - -func (m *mockMetrics) RecordSecurityEvent(event string) { - m.mu.Lock() - defer m.mu.Unlock() - m.securityEvents[event]++ -} - -func (m *mockMetrics) getSuccessCount(source Source) int { - m.mu.Lock() - defer m.mu.Unlock() - return m.successCount[source.String()] -} - -func (m *mockMetrics) getFailureCount(source Source) int { - m.mu.Lock() - defer m.mu.Unlock() - return m.failureCount[source.String()] -} - -func (m *mockMetrics) getSecurityEventCount(event string) int { - m.mu.Lock() - defer m.mu.Unlock() - return m.securityEvents[event] -} - -func TestMetrics_ExtractionSuccess(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:12345", "") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Errorf("Extract() failed: %v", err) - } - - if got := metrics.getSuccessCount(SourceRemoteAddr); got != 1 { - t.Errorf("success count for %s = %d, want 1", SourceRemoteAddr, got) - } -} - -func TestMetrics_ExtractionFailure(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("127.0.0.1:8080", "") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Errorf("Extract() should have failed for loopback IP") - } - - if got := metrics.getFailureCount(SourceRemoteAddr); got != 1 { - t.Errorf("failure count for %s = %d, want 1", SourceRemoteAddr, got) - } -} - -func TestMetrics_SecurityEvent_InvalidIP(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("127.0.0.1:8080", "") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventInvalidIP); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventInvalidIP, got) - } -} - -func TestMetrics_SecurityEvent_PrivateIP(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithAllowPrivateIPs(false), - WithMetrics(metrics), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("192.168.1.1:8080", "") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventPrivateIP); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventPrivateIP, got) - } -} - -func TestMetrics_SecurityEvent_ReservedIP(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("198.51.100.1:8080", "") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventReservedIP); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventReservedIP, got) - } -} - -func TestMetrics_SecurityEvent_ReservedIP_Allowlisted(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithAllowedReservedClientPrefixes(netip.MustParsePrefix("198.51.100.0/24")), - WithMetrics(metrics), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("198.51.100.1:8080", "") - - if _, err := extractor.Extract(req); err != nil { - t.Fatalf("Extract() error = %v", err) - } - - if got := metrics.getSecurityEventCount(securityEventReservedIP); got != 0 { - t.Errorf("security event count for %s = %d, want 0", securityEventReservedIP, got) - } - if got := metrics.getSuccessCount(SourceRemoteAddr); got != 1 { - t.Errorf("success count for %s = %d, want 1", SourceRemoteAddr, got) - } -} - -func TestMetrics_MultipleXFFHeaders_DoNotEmitSecurityEvent(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:8080", "") - req.Header.Add("X-Forwarded-For", "8.8.8.8") - req.Header.Add("X-Forwarded-For", "1.1.1.1") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("expected extraction success, got error: %v", err) - } - - if got := metrics.getSecurityEventCount(securityEventMultipleHeaders); got != 0 { - t.Errorf("security event count for %s = %d, want 0", securityEventMultipleHeaders, got) - } - - if got := metrics.getSuccessCount(SourceXForwardedFor); got != 1 { - t.Errorf("success count for %s = %d, want 1", SourceXForwardedFor, got) - } - - if got := metrics.getFailureCount(SourceXForwardedFor); got != 0 { - t.Errorf("failure count for %s = %d, want 0", SourceXForwardedFor, got) - } -} - -func TestMetrics_SecurityEvent_MultipleSingleIPHeaders(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceXRealIP), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:8080", "") - req.Header.Add("X-Real-IP", "8.8.8.8") - req.Header.Add("X-Real-IP", "9.9.9.9") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail for multiple single-IP headers") - } - - if got := metrics.getSecurityEventCount(securityEventMultipleHeaders); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventMultipleHeaders, got) - } - - if got := metrics.getFailureCount(SourceXRealIP); got != 1 { - t.Errorf("failure count for %s = %d, want 1", SourceXRealIP, got) - } -} - -func TestMetrics_SecurityEvent_TooFewTrustedProxies(t *testing.T) { - metrics := newMockMetrics() - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(2), - WithMaxTrustedProxies(3), - WithMetrics(metrics), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("10.0.0.1:8080", "") - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventTooFewTrustedProxies); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventTooFewTrustedProxies, got) - } -} - -func TestMetrics_SecurityEvent_NoTrustedProxies(t *testing.T) { - metrics := newMockMetrics() - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithMetrics(metrics), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("10.0.0.1:8080", "") - req.Header.Set("X-Forwarded-For", "1.1.1.1") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventNoTrustedProxies); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventNoTrustedProxies, got) - } -} - -func TestMetrics_SecurityEvent_TooManyTrustedProxies(t *testing.T) { - metrics := newMockMetrics() - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(1), - WithMetrics(metrics), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("10.0.0.2:8080", "") - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1, 10.0.0.2") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventTooManyTrustedProxies); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventTooManyTrustedProxies, got) - } -} - -func TestMetrics_SecurityEvent_UntrustedProxy(t *testing.T) { - metrics := newMockMetrics() - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - extractor, err := New( - WithTrustedProxyPrefixes(cidrs...), - WithMinTrustedProxies(1), - WithMaxTrustedProxies(3), - WithMetrics(metrics), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("8.8.8.8:8080", "") - req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventUntrustedProxy); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventUntrustedProxy, got) - } -} - -func TestMetrics_SecurityEvent_ChainTooLong(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMaxChainLength(5), - WithMetrics(metrics), - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("127.0.0.1:8080", "") - req.Header.Set("X-Forwarded-For", "1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5, 6.6.6.6") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventChainTooLong); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventChainTooLong, got) - } -} - -func TestMetrics_SecurityEvent_MalformedForwarded(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - WithTrustedProxyAddrs(netip.MustParseAddr("1.1.1.1")), - WithSourcePriority(SourceForwarded, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("1.1.1.1:8080", "") - req.Header.Set("Forwarded", "for=\"1.1.1.1") - - _, _ = extractor.Extract(req) - - if got := metrics.getSecurityEventCount(securityEventMalformedForwarded); got != 1 { - t.Errorf("security event count for %s = %d, want 1", securityEventMalformedForwarded, got) - } -} - -func TestMetrics_ForwardedSourceSuccess(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceForwarded), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := newTestRequest("127.0.0.1:8080", "") - req.Header.Set("Forwarded", "for=1.1.1.1") - - result, err := extractor.Extract(req) - if err != nil || !result.IP.IsValid() { - t.Fatalf("Extract() failed: %v", err) - } - - if got := metrics.getSuccessCount(SourceForwarded); got != 1 { - t.Errorf("success count for %s = %d, want 1", SourceForwarded, got) - } -} - -func TestMetrics_MultipleExtractions(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - // Successful extraction - req1 := newTestRequest("1.1.1.1:12345", "") - _, _ = extractor.Extract(req1) - - // Another successful extraction - req2 := newTestRequest("8.8.8.8:8080", "") - _, _ = extractor.Extract(req2) - - // Failed extraction - req3 := newTestRequest("127.0.0.1:8080", "") - _, _ = extractor.Extract(req3) - - if got := metrics.getSuccessCount(SourceRemoteAddr); got != 2 { - t.Errorf("success count = %d, want 2", got) - } - - if got := metrics.getFailureCount(SourceRemoteAddr); got != 1 { - t.Errorf("failure count = %d, want 1", got) - } -} - -func TestMetrics_DifferentSources(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - WithTrustedLoopbackProxy(), - WithSourcePriority(SourceXForwardedFor, SourceRemoteAddr), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - // Success from X-Forwarded-For - req1 := newTestRequest("127.0.0.1:8080", "") - req1.Header.Set("X-Forwarded-For", "1.1.1.1") - _, _ = extractor.Extract(req1) - - // Success from RemoteAddr - req2 := newTestRequest("8.8.8.8:8080", "") - _, _ = extractor.Extract(req2) - - if got := metrics.getSuccessCount(SourceXForwardedFor); got != 1 { - t.Errorf("XFF success count = %d, want 1", got) - } - - if got := metrics.getSuccessCount(SourceRemoteAddr); got != 1 { - t.Errorf("RemoteAddr success count = %d, want 1", got) - } -} - -func TestMetrics_ConcurrentAccess(t *testing.T) { - metrics := newMockMetrics() - extractor, err := New( - WithMetrics(metrics), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - const goroutines = 50 - done := make(chan bool, goroutines) - - for i := 0; i < goroutines; i++ { - go func() { - req := newTestRequest("1.1.1.1:12345", "") - _, _ = extractor.Extract(req) - done <- true - }() - } - - for i := 0; i < goroutines; i++ { - <-done - } - - if got := metrics.getSuccessCount(SourceRemoteAddr); got != goroutines { - t.Errorf("success count = %d, want %d", got, goroutines) - } -} - -func TestNoopMetrics(t *testing.T) { - noop := noopMetrics{} - - // Should not panic - noop.RecordExtractionSuccess("test") - noop.RecordExtractionFailure("test") - noop.RecordSecurityEvent("test") -} diff --git a/observability.go b/observability.go new file mode 100644 index 0000000..bcf6f46 --- /dev/null +++ b/observability.go @@ -0,0 +1,72 @@ +package clientip + +import "context" + +// SecurityEvent... constants are stable public labels for extractor security +// events. Metrics adapters and log consumers can depend on these names. +const ( + SecurityEventMultipleHeaders = "multiple_headers" + SecurityEventChainTooLong = "chain_too_long" + SecurityEventUntrustedProxy = "untrusted_proxy" + SecurityEventNoTrustedProxies = "no_trusted_proxies" + SecurityEventTooFewTrustedProxies = "too_few_trusted_proxies" + SecurityEventTooManyTrustedProxies = "too_many_trusted_proxies" + SecurityEventInvalidIP = "invalid_ip" + SecurityEventReservedIP = "reserved_ip" + SecurityEventPrivateIP = "private_ip" + SecurityEventMalformedForwarded = "malformed_forwarded" +) + +// Logger records security-significant events emitted by Extractor. +// +// Implementations should be safe for concurrent use, as a single Extractor +// instance is typically shared across many goroutines. +// +// The provided context comes from the inbound HTTP request and can carry +// tracing metadata (for example, trace or span IDs). +// +// Resolver preferred fallback does not emit separate log events in this phase. +// Inspect Resolution.FallbackUsed when that distinction matters. +// +// The interface intentionally mirrors slog's WarnContext signature, so +// *slog.Logger can be used directly without an adapter. +type Logger interface { + WarnContext(ctx context.Context, msg string, args ...any) +} + +// noopLogger is the default Logger implementation when logging is not +// explicitly configured. +type noopLogger struct{} + +func (noopLogger) WarnContext(context.Context, string, ...any) {} + +// Metrics records extraction outcomes and security events emitted by +// Extractor. +// +// Implementations should be safe for concurrent use, as a single Extractor +// instance is typically shared across many goroutines. +// +// Security event labels are the exported SecurityEvent... constants. +// Resolver preferred fallback does not emit separate metrics in this phase; +// inspect Resolution.FallbackUsed when that distinction matters. +type Metrics interface { + // RecordExtractionSuccess is called when a source successfully returns a + // client IP. + RecordExtractionSuccess(source string) + // RecordExtractionFailure is called when a source is attempted but cannot + // return a valid client IP. + RecordExtractionFailure(source string) + // RecordSecurityEvent is called when the extractor observes a + // security-relevant condition. + RecordSecurityEvent(event string) +} + +// noopMetrics is the default Metrics implementation when metrics are not +// explicitly configured. +type noopMetrics struct{} + +func (noopMetrics) RecordExtractionSuccess(string) {} + +func (noopMetrics) RecordExtractionFailure(string) {} + +func (noopMetrics) RecordSecurityEvent(string) {} diff --git a/observability_test.go b/observability_test.go new file mode 100644 index 0000000..0cee6a9 --- /dev/null +++ b/observability_test.go @@ -0,0 +1,786 @@ +package clientip + +import ( + "context" + "errors" + "net/netip" + "sync" + "testing" +) + +type loggerTestContextKey string + +type capturedLogEntry struct { + ctx context.Context + attrs map[string]any +} + +type capturedLogger struct { + mu sync.Mutex + entries []capturedLogEntry +} + +func (l *capturedLogger) WarnContext(ctx context.Context, msg string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + + l.entries = append(l.entries, capturedLogEntry{ + ctx: ctx, + attrs: attrsToMap(args), + }) +} + +func (l *capturedLogger) snapshot() []capturedLogEntry { + l.mu.Lock() + defer l.mu.Unlock() + + entries := make([]capturedLogEntry, len(l.entries)) + copy(entries, l.entries) + return entries +} + +func attrsToMap(args []any) map[string]any { + attrs := make(map[string]any) + for i := 0; i+1 < len(args); i += 2 { + key, ok := args[i].(string) + if !ok { + continue + } + attrs[key] = args[i+1] + } + return attrs +} + +func assertAttr(t *testing.T, attrs map[string]any, key string, want any) { + t.Helper() + + got, ok := attrs[key] + if !ok { + t.Fatalf("missing %q attr", key) + } + + if got != want { + t.Fatalf("%s attr = %v, want %v", key, got, want) + } +} + +func assertCommonSecurityWarningAttrs(t *testing.T, attrs map[string]any, event string, source Source, path, remoteAddr string) { + t.Helper() + + assertAttr(t, attrs, "event", event) + assertAttr(t, attrs, "source", source.String()) + assertAttr(t, attrs, "path", path) + assertAttr(t, attrs, "remote_addr", remoteAddr) +} + +type mockMetrics struct { + mu sync.Mutex + successCount map[string]int + failureCount map[string]int + securityEvents map[string]int +} + +func newMockMetrics() *mockMetrics { + return &mockMetrics{ + successCount: make(map[string]int), + failureCount: make(map[string]int), + securityEvents: make(map[string]int), + } +} + +func (m *mockMetrics) RecordExtractionSuccess(source string) { + m.mu.Lock() + defer m.mu.Unlock() + m.successCount[source]++ +} + +func (m *mockMetrics) RecordExtractionFailure(source string) { + m.mu.Lock() + defer m.mu.Unlock() + m.failureCount[source]++ +} + +func (m *mockMetrics) RecordSecurityEvent(event string) { + m.mu.Lock() + defer m.mu.Unlock() + m.securityEvents[event]++ +} + +func (m *mockMetrics) getSuccessCount(source Source) int { + m.mu.Lock() + defer m.mu.Unlock() + return m.successCount[source.String()] +} + +func (m *mockMetrics) getFailureCount(source Source) int { + m.mu.Lock() + defer m.mu.Unlock() + return m.failureCount[source.String()] +} + +func (m *mockMetrics) getSecurityEventCount(event string) int { + m.mu.Lock() + defer m.mu.Unlock() + return m.securityEvents[event] +} + +func TestLogging_MultipleXFFHeaders_DoNotWarn(t *testing.T) { + logger := &capturedLogger{} + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:8080", "/test/multiple-headers") + req.Header.Add("X-Forwarded-For", "8.8.8.8") + req.Header.Add("X-Forwarded-For", "9.9.9.9") + + result, err := extractor.Extract(req) + if err != nil || !result.IP.IsValid() { + t.Fatalf("expected extraction success, got error: %v", err) + } + if got, want := result.Source, SourceXForwardedFor; got != want { + t.Fatalf("source = %q, want %q", got, want) + } + if got, want := result.IP.String(), "9.9.9.9"; got != want { + t.Fatalf("ip = %q, want %q", got, want) + } + + entries := logger.snapshot() + if len(entries) != 0 { + t.Fatalf("logged entries = %d, want 0", len(entries)) + } +} + +func TestLogging_MultipleSingleIPHeaders_EmitsWarning(t *testing.T) { + logger := &capturedLogger{} + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXRealIP} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:8080", "/test/multiple-single-ip-headers") + req.Header.Add("X-Real-IP", "8.8.8.8") + req.Header.Add("X-Real-IP", "9.9.9.9") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail for multiple single-IP headers") + } + if !errors.Is(err, ErrMultipleSingleIPHeaders) { + t.Fatalf("error = %v, want ErrMultipleSingleIPHeaders", err) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventMultipleHeaders, SourceXRealIP, "/test/multiple-single-ip-headers", "1.1.1.1:8080") + assertAttr(t, entry.attrs, "header", "X-Real-Ip") + assertAttr(t, entry.attrs, "header_count", 2) +} + +func TestLogging_ChainTooLong_EmitsWarning(t *testing.T) { + logger := &capturedLogger{} + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor} + cfg.MaxChainLength = 2 + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:8080", "/test/chain-too-long") + req.Header.Set("X-Forwarded-For", "8.8.8.8, 9.9.9.9, 4.4.4.4") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail for overlong X-Forwarded-For chain") + } + if !errors.Is(err, ErrChainTooLong) { + t.Fatalf("error = %v, want ErrChainTooLong", err) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventChainTooLong, SourceXForwardedFor, "/test/chain-too-long", "1.1.1.1:8080") + assertAttr(t, entry.attrs, "chain_length", 3) + assertAttr(t, entry.attrs, "max_length", 2) +} + +func TestLogging_TooFewTrustedProxies_EmitsWarning(t *testing.T) { + logger := &capturedLogger{} + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 2 + cfg.MaxTrustedProxies = 3 + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("10.0.0.1:8080", "/test/proxy-count") + req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail for too few trusted proxies") + } + if !errors.Is(err, ErrTooFewTrustedProxies) { + t.Fatalf("error = %v, want ErrTooFewTrustedProxies", err) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventTooFewTrustedProxies, SourceXForwardedFor, "/test/proxy-count", "10.0.0.1:8080") + assertAttr(t, entry.attrs, "trusted_proxy_count", 1) + assertAttr(t, entry.attrs, "min_trusted_proxies", 2) + assertAttr(t, entry.attrs, "max_trusted_proxies", 3) +} + +func TestLogging_NoTrustedProxies_EmitsWarning(t *testing.T) { + logger := &capturedLogger{} + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 3 + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("10.0.0.1:8080", "/test/no-trusted-proxies") + req.Header.Set("X-Forwarded-For", "1.1.1.1") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail when no trusted proxies are present in XFF") + } + if !errors.Is(err, ErrNoTrustedProxies) { + t.Fatalf("error = %v, want ErrNoTrustedProxies", err) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventNoTrustedProxies, SourceXForwardedFor, "/test/no-trusted-proxies", "10.0.0.1:8080") + assertAttr(t, entry.attrs, "trusted_proxy_count", 0) + assertAttr(t, entry.attrs, "min_trusted_proxies", 1) + assertAttr(t, entry.attrs, "max_trusted_proxies", 3) +} + +func TestLogging_TooManyTrustedProxies_EmitsWarning(t *testing.T) { + logger := &capturedLogger{} + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 1 + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("10.0.0.2:8080", "/test/too-many-proxies") + req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1, 10.0.0.2") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail for too many trusted proxies") + } + if !errors.Is(err, ErrTooManyTrustedProxies) { + t.Fatalf("error = %v, want ErrTooManyTrustedProxies", err) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventTooManyTrustedProxies, SourceXForwardedFor, "/test/too-many-proxies", "10.0.0.2:8080") + assertAttr(t, entry.attrs, "trusted_proxy_count", 2) + assertAttr(t, entry.attrs, "min_trusted_proxies", 1) + assertAttr(t, entry.attrs, "max_trusted_proxies", 1) +} + +func TestLogging_UntrustedProxy_EmitsWarning(t *testing.T) { + logger := &capturedLogger{} + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 3 + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("8.8.8.8:8080", "/test/untrusted-proxy") + req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail for untrusted proxy") + } + if !errors.Is(err, ErrUntrustedProxy) { + t.Fatalf("error = %v, want ErrUntrustedProxy", err) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventUntrustedProxy, SourceXForwardedFor, "/test/untrusted-proxy", "8.8.8.8:8080") +} + +func TestLogging_MalformedForwarded_EmitsWarning(t *testing.T) { + logger := &capturedLogger{} + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:8080", "/test/malformed-forwarded") + req.Header.Set("Forwarded", "for=\"1.1.1.1") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail closed on malformed Forwarded") + } + if !errors.Is(err, ErrInvalidForwardedHeader) { + t.Fatalf("error = %v, want ErrInvalidForwardedHeader", err) + } + if result.Source != SourceForwarded { + t.Fatalf("source = %q, want %q", result.Source, SourceForwarded) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventMalformedForwarded, SourceForwarded, "/test/malformed-forwarded", "1.1.1.1:8080") +} + +func TestExtractInput_UsesInputContextAndPathInLogs(t *testing.T) { + logger := &capturedLogger{} + + cfg := DefaultConfig() + cfg.Logger = logger + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor} + cfg.MaxChainLength = 1 + extractor := mustNewExtractor(t, cfg) + + ctx := context.WithValue(context.Background(), loggerTestContextKey("trace_id"), "trace-from-input") + headers := HeaderValuesFunc(func(name string) []string { + if name == "X-Forwarded-For" { + return []string{"8.8.8.8", "9.9.9.9"} + } + return nil + }) + + result, err := extractor.ExtractInput(Input{ + Context: ctx, + RemoteAddr: "1.1.1.1:8080", + Path: "/from-input", + Headers: headers, + }) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction failure for overlong X-Forwarded-For chain") + } + if !errors.Is(err, ErrChainTooLong) { + t.Fatalf("error = %v, want ErrChainTooLong", err) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + if got := entry.ctx.Value(loggerTestContextKey("trace_id")); got != "trace-from-input" { + t.Fatalf("trace context value = %v, want %q", got, "trace-from-input") + } + + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventChainTooLong, SourceXForwardedFor, "/from-input", "1.1.1.1:8080") +} + +func TestMetrics_ExtractionSuccess(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:12345", "") + result, err := extractor.Extract(req) + if err != nil || !result.IP.IsValid() { + t.Errorf("Extract() failed: %v", err) + } + + if got := metrics.getSuccessCount(SourceRemoteAddr); got != 1 { + t.Errorf("success count for %s = %d, want 1", SourceRemoteAddr, got) + } +} + +func TestMetrics_ExtractionFailure(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("127.0.0.1:8080", "") + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Errorf("Extract() should have failed for loopback IP") + } + + if got := metrics.getFailureCount(SourceRemoteAddr); got != 1 { + t.Errorf("failure count for %s = %d, want 1", SourceRemoteAddr, got) + } +} + +func TestMetrics_SecurityEvent_InvalidIP(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + extractor := mustNewExtractor(t, cfg) + + _, _ = extractor.Extract(newTestRequest("127.0.0.1:8080", "")) + + if got := metrics.getSecurityEventCount(SecurityEventInvalidIP); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventInvalidIP, got) + } +} + +func TestMetrics_SecurityEvent_PrivateIP(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.AllowPrivateIPs = false + cfg.Metrics = metrics + extractor := mustNewExtractor(t, cfg) + + _, _ = extractor.Extract(newTestRequest("192.168.1.1:8080", "")) + + if got := metrics.getSecurityEventCount(SecurityEventPrivateIP); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventPrivateIP, got) + } +} + +func TestMetrics_SecurityEvent_ReservedIP(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + extractor := mustNewExtractor(t, cfg) + + _, _ = extractor.Extract(newTestRequest("198.51.100.1:8080", "")) + + if got := metrics.getSecurityEventCount(SecurityEventReservedIP); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventReservedIP, got) + } +} + +func TestMetrics_SecurityEvent_ReservedIP_Allowlisted(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.AllowedReservedClientPrefixes = []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + cfg.Metrics = metrics + extractor := mustNewExtractor(t, cfg) + + if _, err := extractor.Extract(newTestRequest("198.51.100.1:8080", "")); err != nil { + t.Fatalf("Extract() error = %v", err) + } + + if got := metrics.getSecurityEventCount(SecurityEventReservedIP); got != 0 { + t.Errorf("security event count for %s = %d, want 0", SecurityEventReservedIP, got) + } + if got := metrics.getSuccessCount(SourceRemoteAddr); got != 1 { + t.Errorf("success count for %s = %d, want 1", SourceRemoteAddr, got) + } +} + +func TestMetrics_MultipleXFFHeaders_DoNotEmitSecurityEvent(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:8080", "") + req.Header.Add("X-Forwarded-For", "8.8.8.8") + req.Header.Add("X-Forwarded-For", "1.1.1.1") + + result, err := extractor.Extract(req) + if err != nil || !result.IP.IsValid() { + t.Fatalf("expected extraction success, got error: %v", err) + } + + if got := metrics.getSecurityEventCount(SecurityEventMultipleHeaders); got != 0 { + t.Errorf("security event count for %s = %d, want 0", SecurityEventMultipleHeaders, got) + } + if got := metrics.getSuccessCount(SourceXForwardedFor); got != 1 { + t.Errorf("success count for %s = %d, want 1", SourceXForwardedFor, got) + } + if got := metrics.getFailureCount(SourceXForwardedFor); got != 0 { + t.Errorf("failure count for %s = %d, want 0", SourceXForwardedFor, got) + } +} + +func TestMetrics_SecurityEvent_MultipleSingleIPHeaders(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceXRealIP} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:8080", "") + req.Header.Add("X-Real-IP", "8.8.8.8") + req.Header.Add("X-Real-IP", "9.9.9.9") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail for multiple single-IP headers") + } + + if got := metrics.getSecurityEventCount(SecurityEventMultipleHeaders); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventMultipleHeaders, got) + } + if got := metrics.getFailureCount(SourceXRealIP); got != 1 { + t.Errorf("failure count for %s = %d, want 1", SourceXRealIP, got) + } +} + +func TestMetrics_SecurityEvent_TooFewTrustedProxies(t *testing.T) { + metrics := newMockMetrics() + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 2 + cfg.MaxTrustedProxies = 3 + cfg.Metrics = metrics + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("10.0.0.1:8080", "") + req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") + _, _ = extractor.Extract(req) + + if got := metrics.getSecurityEventCount(SecurityEventTooFewTrustedProxies); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventTooFewTrustedProxies, got) + } +} + +func TestMetrics_SecurityEvent_NoTrustedProxies(t *testing.T) { + metrics := newMockMetrics() + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 3 + cfg.Metrics = metrics + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("10.0.0.1:8080", "") + req.Header.Set("X-Forwarded-For", "1.1.1.1") + _, _ = extractor.Extract(req) + + if got := metrics.getSecurityEventCount(SecurityEventNoTrustedProxies); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventNoTrustedProxies, got) + } +} + +func TestMetrics_SecurityEvent_TooManyTrustedProxies(t *testing.T) { + metrics := newMockMetrics() + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 1 + cfg.Metrics = metrics + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("10.0.0.2:8080", "") + req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1, 10.0.0.2") + _, _ = extractor.Extract(req) + + if got := metrics.getSecurityEventCount(SecurityEventTooManyTrustedProxies); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventTooManyTrustedProxies, got) + } +} + +func TestMetrics_SecurityEvent_UntrustedProxy(t *testing.T) { + metrics := newMockMetrics() + cidrs := mustParseCIDRs(t, "10.0.0.0/8") + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = cidrs + cfg.MinTrustedProxies = 1 + cfg.MaxTrustedProxies = 3 + cfg.Metrics = metrics + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("8.8.8.8:8080", "") + req.Header.Set("X-Forwarded-For", "1.1.1.1, 10.0.0.1") + _, _ = extractor.Extract(req) + + if got := metrics.getSecurityEventCount(SecurityEventUntrustedProxy); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventUntrustedProxy, got) + } +} + +func TestMetrics_SecurityEvent_ChainTooLong(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.MaxChainLength = 5 + cfg.Metrics = metrics + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("127.0.0.1:8080", "") + req.Header.Set("X-Forwarded-For", "1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5, 6.6.6.6") + _, _ = extractor.Extract(req) + + if got := metrics.getSecurityEventCount(SecurityEventChainTooLong); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventChainTooLong, got) + } +} + +func TestMetrics_SecurityEvent_MalformedForwarded(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + cfg.TrustedProxyPrefixes = mustProxyPrefixesFromAddrs(t, netip.MustParseAddr("1.1.1.1")) + cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("1.1.1.1:8080", "") + req.Header.Set("Forwarded", "for=\"1.1.1.1") + _, _ = extractor.Extract(req) + + if got := metrics.getSecurityEventCount(SecurityEventMalformedForwarded); got != 1 { + t.Errorf("security event count for %s = %d, want 1", SecurityEventMalformedForwarded, got) + } +} + +func TestMetrics_ForwardedSourceSuccess(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceForwarded} + extractor := mustNewExtractor(t, cfg) + + req := newTestRequest("127.0.0.1:8080", "") + req.Header.Set("Forwarded", "for=1.1.1.1") + result, err := extractor.Extract(req) + if err != nil || !result.IP.IsValid() { + t.Fatalf("Extract() failed: %v", err) + } + + if got := metrics.getSuccessCount(SourceForwarded); got != 1 { + t.Errorf("success count for %s = %d, want 1", SourceForwarded, got) + } +} + +func TestMetrics_MultipleExtractions(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + _, _ = extractor.Extract(newTestRequest("1.1.1.1:12345", "")) + _, _ = extractor.Extract(newTestRequest("8.8.8.8:8080", "")) + _, _ = extractor.Extract(newTestRequest("127.0.0.1:8080", "")) + + if got := metrics.getSuccessCount(SourceRemoteAddr); got != 2 { + t.Errorf("success count = %d, want 2", got) + } + if got := metrics.getFailureCount(SourceRemoteAddr); got != 1 { + t.Errorf("failure count = %d, want 1", got) + } +} + +func TestMetrics_DifferentSources(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXForwardedFor, SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + req1 := newTestRequest("127.0.0.1:8080", "") + req1.Header.Set("X-Forwarded-For", "1.1.1.1") + _, _ = extractor.Extract(req1) + _, _ = extractor.Extract(newTestRequest("8.8.8.8:8080", "")) + + if got := metrics.getSuccessCount(SourceXForwardedFor); got != 1 { + t.Errorf("XFF success count = %d, want 1", got) + } + if got := metrics.getSuccessCount(SourceRemoteAddr); got != 1 { + t.Errorf("RemoteAddr success count = %d, want 1", got) + } +} + +func TestMetrics_ConcurrentAccess(t *testing.T) { + metrics := newMockMetrics() + cfg := DefaultConfig() + cfg.Metrics = metrics + extractor := mustNewExtractor(t, cfg) + + const goroutines = 50 + done := make(chan bool, goroutines) + + for i := 0; i < goroutines; i++ { + go func() { + _, _ = extractor.Extract(newTestRequest("1.1.1.1:12345", "")) + done <- true + }() + } + + for i := 0; i < goroutines; i++ { + <-done + } + + if got := metrics.getSuccessCount(SourceRemoteAddr); got != goroutines { + t.Errorf("success count = %d, want %d", got, goroutines) + } +} + +func TestNoopMetrics(t *testing.T) { + noop := noopMetrics{} + noop.RecordExtractionSuccess("test") + noop.RecordExtractionFailure("test") + noop.RecordSecurityEvent("test") +} diff --git a/options.go b/options.go deleted file mode 100644 index c7dabcc..0000000 --- a/options.go +++ /dev/null @@ -1,203 +0,0 @@ -package clientip - -import ( - "fmt" - "net/netip" - "reflect" - "slices" -) - -func isNilValue(v any) bool { - if v == nil { - return true - } - - rv := reflect.ValueOf(v) - switch rv.Kind() { - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: - return rv.IsNil() - default: - return false - } -} - -// WithTrustedProxyPrefixes adds trusted proxy network prefixes. -func WithTrustedProxyPrefixes(prefixes ...netip.Prefix) Option { - prefixes = clonePrefixes(prefixes) - - return func(c *config) error { - normalized, err := normalizeTrustedProxyPrefixes(prefixes) - if err != nil { - return err - } - - appendTrustedProxyCIDRs(c, normalized...) - return nil - } -} - -// WithTrustedLoopbackProxy adds loopback CIDRs to trusted proxy ranges. -func WithTrustedLoopbackProxy() Option { - return func(c *config) error { - appendTrustedProxyCIDRs(c, loopbackProxyCIDRs...) - return nil - } -} - -// WithTrustedPrivateProxyRanges adds private network CIDRs to trusted proxy ranges. -func WithTrustedPrivateProxyRanges() Option { - return func(c *config) error { - appendTrustedProxyCIDRs(c, privateProxyCIDRs...) - return nil - } -} - -// WithTrustedLocalProxyDefaults adds loopback and private network CIDRs. -func WithTrustedLocalProxyDefaults() Option { - return func(c *config) error { - appendTrustedProxyCIDRs(c, loopbackProxyCIDRs...) - appendTrustedProxyCIDRs(c, privateProxyCIDRs...) - return nil - } -} - -// WithTrustedProxyAddrs adds trusted upstream proxy host addresses. -func WithTrustedProxyAddrs(addrs ...netip.Addr) Option { - addrs = cloneAddrs(addrs) - - return func(c *config) error { - prefixes := make([]netip.Prefix, 0, len(addrs)) - for _, addr := range addrs { - if !addr.IsValid() { - return fmt.Errorf("invalid proxy address %q", addr) - } - - addr = normalizeIP(addr) - prefixes = append(prefixes, netip.PrefixFrom(addr, addr.BitLen())) - } - - appendTrustedProxyCIDRs(c, prefixes...) - return nil - } -} - -// WithMinTrustedProxies sets the minimum trusted proxy count for chain-header sources. -func WithMinTrustedProxies(min int) Option { - return func(c *config) error { - c.minTrustedProxies = min - return nil - } -} - -// WithMaxTrustedProxies sets the maximum trusted proxy count for chain-header sources. -func WithMaxTrustedProxies(max int) Option { - return func(c *config) error { - c.maxTrustedProxies = max - return nil - } -} - -// WithAllowPrivateIPs configures whether private client IPs are accepted. -func WithAllowPrivateIPs(allow bool) Option { - return func(c *config) error { - c.allowPrivateIPs = allow - return nil - } -} - -// WithAllowedReservedClientPrefixes configures reserved client prefixes to explicitly allow. -func WithAllowedReservedClientPrefixes(prefixes ...netip.Prefix) Option { - prefixes = clonePrefixes(prefixes) - - return func(c *config) error { - normalized, err := normalizeReservedClientPrefixes(prefixes) - if err != nil { - return err - } - - c.allowReservedClientPrefixes = mergeUniquePrefixes(c.allowReservedClientPrefixes, normalized...) - return nil - } -} - -// WithMaxChainLength sets the maximum number of entries accepted in proxy chains. -func WithMaxChainLength(max int) Option { - return func(c *config) error { - c.maxChainLength = max - return nil - } -} - -// WithLogger sets the logger implementation used for warning events. -func WithLogger(logger Logger) Option { - return func(c *config) error { - c.logger = logger - return nil - } -} - -// WithMetrics sets a concrete metrics implementation. -// -// If previously configured, a metrics factory is disabled. -func WithMetrics(metrics Metrics) Option { - return func(c *config) error { - c.metrics = metrics - c.metricsFactory = nil - c.useMetricsFactory = false - return nil - } -} - -// WithMetricsFactory configures a lazy metrics constructor. -// -// The factory is invoked only for the final winning metrics option after -// option validation succeeds. -func WithMetricsFactory(factory func() (Metrics, error)) Option { - return func(c *config) error { - if factory == nil { - return fmt.Errorf("metrics factory cannot be nil") - } - - c.metricsFactory = factory - c.useMetricsFactory = true - return nil - } -} - -// WithSourcePriority sets extraction source order. -func WithSourcePriority(sources ...Source) Option { - resolvedSources := canonicalizeSources(slices.Clone(sources)) - - return func(c *config) error { - if len(resolvedSources) == 0 { - return fmt.Errorf("at least one source required in WithSourcePriority") - } - - c.sourcePriority = slices.Clone(resolvedSources) - return nil - } -} - -// WithChainSelection sets how client candidates are chosen from chain headers. -func WithChainSelection(selection ChainSelection) Option { - return func(c *config) error { - c.chainSelection = selection - return nil - } -} - -// WithDebugInfo controls whether chain-debug metadata is included in results. -func WithDebugInfo(enable bool) Option { - return func(c *config) error { - c.debugMode = enable - return nil - } -} - -// WithSecurityMode sets strict or lax fallback behavior after security errors. -func WithSecurityMode(mode SecurityMode) Option { - return func(c *config) error { - c.securityMode = mode - return nil - } -} diff --git a/parse_benchmark_test.go b/parse_benchmark_test.go new file mode 100644 index 0000000..f3fbac2 --- /dev/null +++ b/parse_benchmark_test.go @@ -0,0 +1,18 @@ +package clientip + +import "testing" + +func BenchmarkParseIP(b *testing.B) { + testCases := []string{"1.1.1.1", " 1.1.1.1 ", "1.1.1.1:8080", "[2606:4700:4700::1]", "[2606:4700:4700::1]:8080", `"1.1.1.1"`} + + for _, tc := range testCases { + b.Run(tc, func(b *testing.B) { + for i := 0; i < b.N; i++ { + ip := parseIP(tc) + if !ip.IsValid() { + b.Fatal("parsing failed") + } + } + }) + } +} diff --git a/chain_capacity.go b/parse_chain_capacity.go similarity index 85% rename from chain_capacity.go rename to parse_chain_capacity.go index d688951..c47d32a 100644 --- a/chain_capacity.go +++ b/parse_chain_capacity.go @@ -2,12 +2,9 @@ package clientip import "strings" -// typicalChainCapacity is the default initial capacity used when parsing proxy -// chains. const typicalChainCapacity = 8 -func (e *Extractor) chainPartsCapacity(values []string) int { - maxLength := e.config.maxChainLength +func chainPartsCapacity(values []string, maxLength int) int { if maxLength <= 0 { maxLength = 1 } diff --git a/parse_errors.go b/parse_errors.go new file mode 100644 index 0000000..09dfb82 --- /dev/null +++ b/parse_errors.go @@ -0,0 +1,12 @@ +package clientip + +import "fmt" + +type chainTooLongParseError struct { + ChainLength int + MaxLength int +} + +func (e *chainTooLongParseError) Error() string { + return fmt.Sprintf("proxy chain too long (chain_length=%d, max_length=%d)", e.ChainLength, e.MaxLength) +} diff --git a/forwarded.go b/parse_forwarded.go similarity index 65% rename from forwarded.go rename to parse_forwarded.go index 750d4bf..050ceb1 100644 --- a/forwarded.go +++ b/parse_forwarded.go @@ -6,20 +6,12 @@ import ( "strings" ) -// parseForwardedValues extracts the Forwarded for= chain from one or more -// Forwarded header values. -// -// Header values and elements are processed in wire order. Elements without a -// for parameter are ignored. Any parse failure is converted to an -// ErrInvalidForwardedHeader extraction error with SourceForwarded. -// -// The returned chain is bounded by the configured maxChainLength. -func (e *Extractor) parseForwardedValues(values []string) ([]string, error) { +func parseForwardedValues(values []string, maxChainLength int) ([]string, error) { if len(values) == 0 { return nil, nil } - parts := make([]string, 0, e.chainPartsCapacity(values)) + parts := make([]string, 0, chainPartsCapacity(values, maxChainLength)) for _, value := range values { err := scanForwardedSegments(value, ',', func(element string) error { @@ -31,37 +23,29 @@ func (e *Extractor) parseForwardedValues(values []string) ([]string, error) { return nil } - var appendErr error - parts, appendErr = e.appendChainPart(parts, forwardedFor, builtinSource(sourceForwarded)) - return appendErr + if len(parts) >= maxChainLength { + return &chainTooLongParseError{ + ChainLength: len(parts) + 1, + MaxLength: maxChainLength, + } + } + + parts = append(parts, forwardedFor) + return nil }) if err != nil { - if errors.Is(err, ErrChainTooLong) { + var chainErr *chainTooLongParseError + if errors.As(err, &chainErr) { return nil, err } - return nil, invalidForwardedHeaderError(err) + return nil, err } } return parts, nil } -// invalidForwardedHeaderError wraps low-level parse errors as an extraction -// error tagged with ErrInvalidForwardedHeader and SourceForwarded. -func invalidForwardedHeaderError(err error) error { - return &ExtractionError{ - Err: fmt.Errorf("%w: %w", ErrInvalidForwardedHeader, err), - Source: builtinSource(sourceForwarded), - } -} - -// parseForwardedElement parses a single Forwarded element and returns its for -// parameter value when present. -// -// It allows arbitrary additional parameters, treats the parameter name -// case-insensitively, and rejects duplicate for parameters in the same -// element. func parseForwardedElement(element string) (forwardedFor string, hasFor bool, err error) { err = scanForwardedSegments(element, ';', func(param string) error { eq := strings.IndexByte(param, '=') @@ -102,8 +86,6 @@ func parseForwardedElement(element string) (forwardedFor string, hasFor bool, er return forwardedFor, hasFor, nil } -// scanForwardedSegments splits value by delimiter while respecting quoted -// segments and escape sequences inside quoted strings. func scanForwardedSegments(value string, delimiter byte, onSegment func(string) error) error { start := 0 inQuotes := false @@ -153,10 +135,6 @@ func scanForwardedSegments(value string, delimiter byte, onSegment func(string) return nil } -// parseForwardedForValue parses a Forwarded for parameter value. -// -// The value may be an unquoted token or a quoted string. For quoted strings, -// escaping is handled by unquoteForwardedValue. func parseForwardedForValue(value string) (string, error) { value = strings.TrimSpace(value) if value == "" { @@ -178,8 +156,6 @@ func parseForwardedForValue(value string) (string, error) { return value, nil } -// unquoteForwardedValue removes surrounding quotes from a Forwarded quoted -// string and resolves backslash escapes. func unquoteForwardedValue(value string) (string, error) { if len(value) < 2 || value[0] != '"' || value[len(value)-1] != '"' { return "", fmt.Errorf("invalid quoted string %q", value) diff --git a/parse_forwarded_test.go b/parse_forwarded_test.go new file mode 100644 index 0000000..0e9a290 --- /dev/null +++ b/parse_forwarded_test.go @@ -0,0 +1,125 @@ +package clientip + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestParseForwardedValues(t *testing.T) { + tests := []struct { + name string + values []string + want []string + wantErr bool + }{ + {name: "single for value", values: []string{"for=1.1.1.1"}, want: []string{"1.1.1.1"}}, + {name: "case-insensitive parameter name", values: []string{"For=1.1.1.1"}, want: []string{"1.1.1.1"}}, + {name: "multiple elements in one header", values: []string{"for=1.1.1.1, for=8.8.8.8"}, want: []string{"1.1.1.1", "8.8.8.8"}}, + {name: "multiple header lines", values: []string{"for=1.1.1.1", "for=8.8.8.8"}, want: []string{"1.1.1.1", "8.8.8.8"}}, + {name: "parameters with semicolons", values: []string{"for=1.1.1.1;proto=https;by=10.0.0.1"}, want: []string{"1.1.1.1"}}, + {name: "quoted IPv6 and port", values: []string{"for=\"[2606:4700:4700::1]:8080\""}, want: []string{"[2606:4700:4700::1]:8080"}}, + {name: "quoted comma is not treated as element delimiter", values: []string{"for=\"1.1.1.1,8.8.8.8\";proto=https"}, want: []string{"1.1.1.1,8.8.8.8"}}, + {name: "quoted semicolon is not treated as param delimiter", values: []string{"for=\"1.1.1.1;edge\";proto=https"}, want: []string{"1.1.1.1;edge"}}, + {name: "escaped quote remains inside quoted value", values: []string{`for="1.1.1.1\";edge";proto=https`}, want: []string{`1.1.1.1";edge`}}, + {name: "ignores element without for parameter", values: []string{"proto=https;by=10.0.0.1, for=8.8.8.8"}, want: []string{"8.8.8.8"}}, + {name: "invalid parameter format", values: []string{"for"}, wantErr: true}, + {name: "unterminated quoted string", values: []string{"for=\"1.1.1.1"}, wantErr: true}, + {name: "duplicate for parameter", values: []string{"for=1.1.1.1;for=8.8.8.8"}, wantErr: true}, + {name: "trailing escape in quoted value", values: []string{`for="1.1.1.1\`}, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseForwardedValues(tt.values, 100) + + if tt.wantErr { + if err == nil { + t.Fatalf("parseForwardedValues() error = nil, want parse error") + } + return + } + + if err != nil { + t.Fatalf("parseForwardedValues() error = %v, want nil", err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("parseForwardedValues() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestParseForwardedValues_MaxChainLength(t *testing.T) { + _, err := parseForwardedValues([]string{"for=1.1.1.1, for=2.2.2.2, for=3.3.3.3"}, 2) + + var chainErr *chainTooLongParseError + if !errors.As(err, &chainErr) { + t.Fatalf("parseForwardedValues() error = %v, want chainTooLongParseError", err) + } +} + +func TestParseForwardedValues_MalformedParameterMatrix(t *testing.T) { + tests := []struct { + name string + values []string + }{ + {name: "empty parameter key", values: []string{"=1.1.1.1"}}, + {name: "empty for value", values: []string{"for="}}, + {name: "empty quoted for value", values: []string{`for=""`}}, + {name: "invalid quoted for value suffix", values: []string{`for="1.1.1.1"extra`}}, + {name: "non for parameter missing equals", values: []string{"for=1.1.1.1;proto"}}, + {name: "non for parameter empty key", values: []string{"for=1.1.1.1;=https"}}, + {name: "non for parameter empty value", values: []string{"for=1.1.1.1;proto="}}, + {name: "unterminated quoted value across params", values: []string{"for=1.1.1.1;proto=\"https"}}, + {name: "unbalanced quotes in element", values: []string{`for="a"b"`}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := parseForwardedValues(tt.values, 100) + if err == nil { + t.Fatalf("parseForwardedValues() error = nil, want parse error") + } + }) + } +} + +func TestParseForwardedForValue(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {name: "unquoted token", input: "1.1.1.1", want: "1.1.1.1"}, + {name: "quoted token", input: `"1.1.1.1"`, want: "1.1.1.1"}, + {name: "quoted token with surrounding spaces", input: ` "1.1.1.1" `, want: "1.1.1.1"}, + {name: "escaped quote in quoted token", input: `"1.1.1.1\\\"edge"`, want: `1.1.1.1\"edge`}, + {name: "empty input", input: "", wantErr: true}, + {name: "spaces only", input: " ", wantErr: true}, + {name: "unterminated quote", input: `"1.1.1.1`, wantErr: true}, + {name: "unexpected inner quote", input: `"a"b"`, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseForwardedForValue(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("parseForwardedForValue() error = nil, want error") + } + return + } + + if err != nil { + t.Fatalf("parseForwardedForValue() error = %v, want nil", err) + } + if got != tt.want { + t.Fatalf("parseForwardedForValue() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/parse_fuzz_test.go b/parse_fuzz_test.go new file mode 100644 index 0000000..7dd8129 --- /dev/null +++ b/parse_fuzz_test.go @@ -0,0 +1,117 @@ +package clientip + +import ( + "errors" + "testing" +) + +func FuzzParseIP_RoundTripNormalization(f *testing.F) { + for _, seed := range []string{"1.1.1.1", " 1.1.1.1 ", "1.1.1.1:443", "[2606:4700:4700::1]:443", `"1.1.1.1"`, `'1.1.1.1'`, "not-an-ip", ""} { + f.Add(seed) + } + + f.Fuzz(func(t *testing.T, raw string) { + parsed := parseIP(raw) + if !parsed.IsValid() { + return + } + + roundTrip := parseIP(parsed.String()) + if !roundTrip.IsValid() { + t.Fatalf("round-trip parse invalid for %q (%q)", raw, parsed.String()) + } + + if parsed.Unmap() != roundTrip.Unmap() { + t.Fatalf("normalized round-trip mismatch for %q", raw) + } + }) +} + +func FuzzParseRemoteAddr_RoundTripNormalization(f *testing.F) { + for _, seed := range []string{"1.1.1.1:443", "[2606:4700:4700::1]:443", "1.1.1.1", "2606:4700:4700::1", "example.com:443", ""} { + f.Add(seed) + } + + f.Fuzz(func(t *testing.T, raw string) { + parsed := parseRemoteAddr(raw) + if !parsed.IsValid() { + return + } + + roundTrip := parseIP(parsed.String()) + if !roundTrip.IsValid() { + t.Fatalf("round-trip parse invalid for remote addr %q (%q)", raw, parsed.String()) + } + + if parsed.Unmap() != roundTrip.Unmap() { + t.Fatalf("normalized round-trip mismatch for remote addr %q", raw) + } + }) +} + +func FuzzParseXFFValues_ErrorShapeAndOutput(f *testing.F) { + for _, seed := range []string{"1.1.1.1", "1.1.1.1, 8.8.8.8", "1.1.1.1, , 8.8.8.8", "\t1.1.1.1\t", ",", ", ,", ""} { + f.Add(seed) + } + + f.Fuzz(func(t *testing.T, raw string) { + valueSets := [][]string{{raw}, {raw, raw}, {"1.1.1.1", raw}, {raw, "8.8.8.8"}} + + for _, values := range valueSets { + parts, parseErr := parseXFFValues(values, 16) + + if parseErr != nil { + var chainErr *chainTooLongParseError + if !errors.As(parseErr, &chainErr) { + t.Fatalf("unexpected parseXFFValues error type for %#v: %v", values, parseErr) + } + continue + } + + if len(parts) > 16 { + t.Fatalf("parts length = %d, max = %d", len(parts), 16) + } + + for i, part := range parts { + if part == "" { + t.Fatalf("empty part at index %d", i) + } + if part != trimHTTPWhitespace(part) { + t.Fatalf("part has untrimmed HTTP whitespace at index %d: %q", i, part) + } + } + } + }) +} + +func FuzzParseForwardedValues_ErrorShapeAndOutput(f *testing.F) { + for _, seed := range []string{"for=1.1.1.1", "for=1.1.1.1, for=8.8.8.8", "for=1.1.1.1;proto=https", `for="[2606:4700:4700::1]:443"`, `for="1.1.1.1\"edge"`, "for", `for="unterminated`, ""} { + f.Add(seed) + } + + f.Fuzz(func(t *testing.T, raw string) { + valueSets := [][]string{{raw}, {raw, raw}, {"for=1.1.1.1", raw}, {raw, "for=8.8.8.8"}} + + for _, values := range valueSets { + parts, parseErr := parseForwardedValues(values, 16) + + if parseErr != nil { + var chainErr *chainTooLongParseError + if errors.As(parseErr, &chainErr) { + continue + } + continue + } + + if len(parts) > 16 { + t.Fatalf("parts length = %d, max = %d", len(parts), 16) + } + + for i, part := range parts { + if part == "" { + t.Fatalf("empty forwarded part at index %d", i) + } + } + } + }) +} diff --git a/ip_parse.go b/parse_ip.go similarity index 69% rename from ip_parse.go rename to parse_ip.go index 97b182e..78498bd 100644 --- a/ip_parse.go +++ b/parse_ip.go @@ -6,20 +6,40 @@ import ( "strings" ) -// parseIP extracts an IP address from various formats found in proxy headers. -// It handles: -// - Leading/trailing whitespace: " 192.168.1.1 " -// - Port suffixes: "192.168.1.1:8080" or "[::1]:8080" -// - Quoted values: "\"192.168.1.1\"" or "'192.168.1.1'" -// - IPv6 brackets: "[::1]" -// -// The function normalizes these common variations before calling -// netip.ParseAddr for the actual parsing. This approach is lenient with -// formatting (trimming, removing ports/quotes) but still relies on Go's -// standard IP validation. Validation of whether the IP is plausible (not -// loopback, private, etc.) is handled separately by isPlausibleClientIP. -// -// Returns an invalid netip.Addr (IsValid() == false) if parsing fails. +// normalizeIP unmaps IPv4-in-IPv6 addresses to their IPv4 form. +func normalizeIP(ip netip.Addr) netip.Addr { + if ip.Is4In6() { + return ip.Unmap() + } + + return ip +} + +// parseChainIP parses an IP from a chain value that has already been +// extracted and trimmed by a header parser (XFF, Forwarded). +// It handles plain IPs and the [ip]:port format from Forwarded headers, +// but skips the quote-stripping and fallback paths of parseIP. +func parseChainIP(s string) netip.Addr { + ip, err := netip.ParseAddr(s) + if err == nil { + return ip + } + + // Handle [ip]:port from Forwarded header values. + // Extract the content between [ and ] without importing net. + if len(s) > 2 && s[0] == '[' { + if end := strings.IndexByte(s, ']'); end > 1 { + ip, err = netip.ParseAddr(s[1:end]) + if err == nil { + return ip + } + } + } + + return netip.Addr{} +} + +// parseIP extracts an IP address from the formats commonly found in proxy headers. func parseIP(s string) netip.Addr { s = strings.TrimSpace(s) if s == "" { @@ -63,15 +83,7 @@ func parseIP(s string) netip.Addr { return ip } -func parseHostIP(host string) (netip.Addr, bool) { - ip, err := netip.ParseAddr(host) - if err == nil { - return ip, true - } - - return parseNormalizedIP(host) -} - +// parseRemoteAddr extracts an IP address from Request.RemoteAddr-like input. func parseRemoteAddr(s string) netip.Addr { host, ok := splitHostPortHost(s) if !ok { @@ -86,6 +98,15 @@ func parseRemoteAddr(s string) netip.Addr { return ip } +func parseHostIP(host string) (netip.Addr, bool) { + ip, err := netip.ParseAddr(host) + if err == nil { + return ip, true + } + + return parseNormalizedIP(host) +} + func looksLikeHostPort(s string) bool { if len(s) < 3 { return false @@ -127,14 +148,6 @@ func parseNormalizedIP(s string) (netip.Addr, bool) { return ip, true } -func normalizeIP(ip netip.Addr) netip.Addr { - if ip.Is4In6() { - return ip.Unmap() - } - return ip -} - -// trimMatchedPair removes one leading and trailing delimiter when both match. func trimMatchedPair(s string, start, end byte) string { if len(s) < 2 { return s @@ -147,7 +160,6 @@ func trimMatchedPair(s string, start, end byte) string { return s[1 : len(s)-1] } -// trimMatchedChar removes one matching leading and trailing character. func trimMatchedChar(s string, ch byte) string { return trimMatchedPair(s, ch, ch) } diff --git a/parse_ip_test.go b/parse_ip_test.go new file mode 100644 index 0000000..c774e60 --- /dev/null +++ b/parse_ip_test.go @@ -0,0 +1,174 @@ +package clientip + +import ( + "net/netip" + "testing" +) + +func TestParseIP(t *testing.T) { + tests := []struct { + name string + input string + want netip.Addr + wantErr bool + }{ + {name: "valid IPv4", input: "203.0.113.1", want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with leading whitespace", input: " 203.0.113.1", want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with trailing whitespace", input: "203.0.113.1 ", want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with surrounding whitespace", input: " 203.0.113.1 ", want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with tabs", input: "\t203.0.113.1\t", want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with port", input: "203.0.113.1:8080", want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with double quotes", input: `"203.0.113.1"`, want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with single quotes", input: "'203.0.113.1'", want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv4 with quotes and port", input: `"203.0.113.1:8080"`, want: netip.MustParseAddr("203.0.113.1")}, + {name: "valid IPv6", input: "2001:db8::1", want: netip.MustParseAddr("2001:db8::1")}, + {name: "valid IPv6 with brackets", input: "[2001:db8::1]", want: netip.MustParseAddr("2001:db8::1")}, + {name: "valid IPv6 with brackets and port", input: "[2001:db8::1]:8080", want: netip.MustParseAddr("2001:db8::1")}, + {name: "valid IPv6 with whitespace and brackets", input: " [2001:db8::1] ", want: netip.MustParseAddr("2001:db8::1")}, + {name: "localhost IPv4", input: "127.0.0.1", want: netip.MustParseAddr("127.0.0.1")}, + {name: "localhost IPv4 with port", input: "127.0.0.1:8080", want: netip.MustParseAddr("127.0.0.1")}, + {name: "localhost IPv6", input: "::1", want: netip.MustParseAddr("::1")}, + {name: "localhost IPv6 with brackets and port", input: "[::1]:8080", want: netip.MustParseAddr("::1")}, + {name: "empty string", input: "", wantErr: true}, + {name: "whitespace only", input: " ", wantErr: true}, + {name: "quotes only", input: `""`, wantErr: true}, + {name: "unmatched leading double quote", input: `"203.0.113.1`, wantErr: true}, + {name: "unmatched trailing double quote", input: `203.0.113.1"`, wantErr: true}, + {name: "unmatched leading single quote", input: "'203.0.113.1", wantErr: true}, + {name: "unmatched trailing single quote", input: "203.0.113.1'", wantErr: true}, + {name: "invalid IP", input: "not-an-ip", wantErr: true}, + {name: "invalid IPv4", input: "999.999.999.999", wantErr: true}, + {name: "port only", input: ":8080", wantErr: true}, + {name: "brackets only", input: "[]", wantErr: true}, + {name: "unmatched leading bracket", input: "[2001:db8::1", wantErr: true}, + {name: "unmatched trailing bracket", input: "2001:db8::1]", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseIP(tt.input) + if tt.wantErr { + if got.IsValid() { + t.Errorf("parseIP(%q) = %v, want invalid", tt.input, got) + } + return + } + + if !got.IsValid() { + t.Errorf("parseIP(%q) = invalid, want %v", tt.input, tt.want) + return + } + if got != tt.want { + t.Errorf("parseIP(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func Test_parseRemoteAddr(t *testing.T) { + tests := []struct { + name string + input string + want netip.Addr + wantErr bool + }{ + {name: "ipv4 host:port", input: "203.0.113.1:8080", want: netip.MustParseAddr("203.0.113.1")}, + {name: "ipv6 host:port", input: "[2001:db8::1]:443", want: netip.MustParseAddr("2001:db8::1")}, + {name: "bare ipv4 fallback", input: "203.0.113.1", want: netip.MustParseAddr("203.0.113.1")}, + {name: "bare ipv6 fallback", input: "2001:db8::1", want: netip.MustParseAddr("2001:db8::1")}, + {name: "bracketed ipv6 fallback", input: "[2001:db8::1]", want: netip.MustParseAddr("2001:db8::1")}, + {name: "quoted ipv4 fallback", input: `"203.0.113.1"`, want: netip.MustParseAddr("203.0.113.1")}, + {name: "hostname with port", input: "example.com:443", wantErr: true}, + {name: "non-numeric port is ignored", input: "203.0.113.1:notaport", want: netip.MustParseAddr("203.0.113.1")}, + {name: "empty", input: "", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseRemoteAddr(tt.input) + if tt.wantErr { + if got.IsValid() { + t.Errorf("parseRemoteAddr(%q) = %v, want invalid", tt.input, got) + } + return + } + + if !got.IsValid() { + t.Errorf("parseRemoteAddr(%q) = invalid, want %v", tt.input, tt.want) + return + } + + if got != tt.want { + t.Errorf("parseRemoteAddr(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestNormalizeIP(t *testing.T) { + tests := []struct { + name string + input netip.Addr + want netip.Addr + }{ + {name: "IPv4 - no change", input: netip.MustParseAddr("203.0.113.1"), want: netip.MustParseAddr("203.0.113.1")}, + {name: "IPv6 - no change", input: netip.MustParseAddr("2001:db8::1"), want: netip.MustParseAddr("2001:db8::1")}, + {name: "IPv4-mapped IPv6 - unmapped", input: netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 203, 0, 113, 1}), want: netip.MustParseAddr("203.0.113.1")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeIP(tt.input) + if got != tt.want { + t.Errorf("normalizeIP(%v) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestTrimMatchedChar(t *testing.T) { + tests := []struct { + name string + input string + ch byte + want string + }{ + {name: "matching double quote delimiter", input: `"203.0.113.1"`, ch: '"', want: "203.0.113.1"}, + {name: "matching single quote delimiter", input: "'203.0.113.1'", ch: '\'', want: "203.0.113.1"}, + {name: "non-matching delimiter", input: "203.0.113.1", ch: '"', want: "203.0.113.1"}, + {name: "too short to trim", input: `"`, ch: '"', want: `"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := trimMatchedChar(tt.input, tt.ch) + if got != tt.want { + t.Errorf("trimMatchedChar(%q, %q) = %q, want %q", tt.input, tt.ch, got, tt.want) + } + }) + } +} + +func TestTrimMatchedPair(t *testing.T) { + tests := []struct { + name string + input string + start byte + end byte + want string + }{ + {name: "matching pair", input: "[2001:db8::1]", start: '[', end: ']', want: "2001:db8::1"}, + {name: "unmatched leading bracket", input: "[2001:db8::1", start: '[', end: ']', want: "[2001:db8::1"}, + {name: "unmatched trailing bracket", input: "2001:db8::1]", start: '[', end: ']', want: "2001:db8::1]"}, + {name: "too short to trim", input: "[", start: '[', end: ']', want: "["}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := trimMatchedPair(tt.input, tt.start, tt.end) + if got != tt.want { + t.Errorf("trimMatchedPair(%q, %q, %q) = %q, want %q", tt.input, tt.start, tt.end, got, tt.want) + } + }) + } +} diff --git a/parse_remote_addr.go b/parse_remote_addr.go new file mode 100644 index 0000000..71727de --- /dev/null +++ b/parse_remote_addr.go @@ -0,0 +1,21 @@ +package clientip + +import "net/netip" + +// ParseRemoteAddr parses and normalizes Request.RemoteAddr-style input without +// applying extractor plausibility policy. +func ParseRemoteAddr(remoteAddr string) (netip.Addr, error) { + if remoteAddr == "" { + return netip.Addr{}, &ExtractionError{Err: ErrSourceUnavailable, Source: SourceRemoteAddr} + } + + ip := parseRemoteAddr(remoteAddr) + if !ip.IsValid() { + return netip.Addr{}, &RemoteAddrError{ + ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: SourceRemoteAddr}, + RemoteAddr: remoteAddr, + } + } + + return normalizeIP(ip), nil +} diff --git a/parse_remote_addr_test.go b/parse_remote_addr_test.go new file mode 100644 index 0000000..153b7c2 --- /dev/null +++ b/parse_remote_addr_test.go @@ -0,0 +1,45 @@ +package clientip + +import ( + "errors" + "testing" +) + +func TestParseRemoteAddr(t *testing.T) { + tests := []struct { + name string + remoteAddr string + wantIP string + wantErr error + wantErrType any + }{ + {name: "host port", remoteAddr: "8.8.8.8:443", wantIP: "8.8.8.8"}, + {name: "bracketed ipv6 host port", remoteAddr: "[2001:db8::1]:443", wantIP: "2001:db8::1"}, + {name: "bare ip", remoteAddr: "2001:db8::1", wantIP: "2001:db8::1"}, + {name: "mapped ipv4 normalized", remoteAddr: "[::ffff:192.0.2.10]:443", wantIP: "192.0.2.10"}, + {name: "empty", wantErr: ErrSourceUnavailable, wantErrType: &ExtractionError{}}, + {name: "invalid", remoteAddr: "bad-remote-addr", wantErr: ErrInvalidIP, wantErrType: &RemoteAddrError{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseRemoteAddr(tt.remoteAddr) + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("error = %v, want %v", err, tt.wantErr) + } + if !errorIsType(err, tt.wantErrType) { + t.Fatalf("error type = %T, want %T", err, tt.wantErrType) + } + return + } + + if err != nil { + t.Fatalf("ParseRemoteAddr() error = %v", err) + } + if got.String() != tt.wantIP { + t.Fatalf("IP = %q, want %q", got, tt.wantIP) + } + }) + } +} diff --git a/parse_xff.go b/parse_xff.go new file mode 100644 index 0000000..83c92eb --- /dev/null +++ b/parse_xff.go @@ -0,0 +1,54 @@ +package clientip + +import "strings" + +func parseXFFValues(values []string, maxChainLength int) ([]string, error) { + if len(values) == 0 { + return nil, nil + } + + // Fast path: single header value with no commas and no surrounding whitespace. + // Return the input slice directly to avoid allocation. + if len(values) == 1 { + v := values[0] + if strings.IndexByte(v, ',') == -1 { + trimmed := trimHTTPWhitespace(v) + if trimmed == "" { + return nil, nil + } + if maxChainLength <= 0 { + return nil, &chainTooLongParseError{ChainLength: 1, MaxLength: maxChainLength} + } + if trimmed == v { + return values, nil + } + return []string{trimmed}, nil + } + } + + parts := make([]string, 0, chainPartsCapacity(values, maxChainLength)) + for _, v := range values { + start := 0 + for i := 0; i <= len(v); i++ { + if i != len(v) && v[i] != ',' { + continue + } + + part := trimHTTPWhitespace(v[start:i]) + if part != "" { + if len(parts) >= maxChainLength { + return nil, &chainTooLongParseError{ + ChainLength: len(parts) + 1, + MaxLength: maxChainLength, + } + } + + parts = append(parts, part) + } + + start = i + 1 + } + } + + return parts, nil +} diff --git a/parse_xff_test.go b/parse_xff_test.go new file mode 100644 index 0000000..7f36280 --- /dev/null +++ b/parse_xff_test.go @@ -0,0 +1,68 @@ +package clientip + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestParseXFFValues(t *testing.T) { + tests := []struct { + name string + values []string + want []string + }{ + {name: "single value", values: []string{"1.1.1.1"}, want: []string{"1.1.1.1"}}, + {name: "single value with multiple IPs", values: []string{"1.1.1.1, 8.8.8.8"}, want: []string{"1.1.1.1", "8.8.8.8"}}, + {name: "multiple values combined", values: []string{"1.1.1.1", "8.8.8.8"}, want: []string{"1.1.1.1", "8.8.8.8"}}, + {name: "whitespace trimmed", values: []string{" 1.1.1.1 , 8.8.8.8 "}, want: []string{"1.1.1.1", "8.8.8.8"}}, + {name: "empty strings ignored", values: []string{"1.1.1.1, , 8.8.8.8"}, want: []string{"1.1.1.1", "8.8.8.8"}}, + {name: "empty list", values: []string{}, want: nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseXFFValues(tt.values, 100) + if err != nil { + t.Fatalf("parseXFFValues() error = %v, want nil", err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("parseXFFValues() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestParseXFFValues_MaxChainLength(t *testing.T) { + _, err := parseXFFValues([]string{"1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5, 6.6.6.6, 7.7.7.7"}, 5) + + var chainErr *chainTooLongParseError + if !errors.As(err, &chainErr) { + t.Fatalf("parseXFFValues() error = %v, want chainTooLongParseError", err) + } +} + +func TestParseXFFValues_PreservesWireOrderAcrossHeaderLines(t *testing.T) { + values := []string{"1.1.1.1, 8.8.8.8", "9.9.9.9", " 4.4.4.4 , 5.5.5.5 "} + + got, err := parseXFFValues(values, 10) + if err != nil { + t.Fatalf("parseXFFValues() error = %v", err) + } + + want := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9", "4.4.4.4", "5.5.5.5"} + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("parseXFFValues() mismatch (-want +got):\n%s", diff) + } +} + +func TestParseXFFValues_MaxChainLength_AcrossHeaderLines(t *testing.T) { + _, err := parseXFFValues([]string{"1.1.1.1, 8.8.8.8", "9.9.9.9", "4.4.4.4"}, 3) + + var chainErr *chainTooLongParseError + if !errors.As(err, &chainErr) { + t.Fatalf("parseXFFValues() error = %v, want chainTooLongParseError", err) + } +} diff --git a/parser_fuzz_test.go b/parser_fuzz_test.go deleted file mode 100644 index 9269a3c..0000000 --- a/parser_fuzz_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package clientip - -import ( - "errors" - "testing" -) - -func FuzzParseIP_RoundTripNormalization(f *testing.F) { - for _, seed := range []string{ - "1.1.1.1", - " 1.1.1.1 ", - "1.1.1.1:443", - "[2606:4700:4700::1]:443", - `"1.1.1.1"`, - `'1.1.1.1'`, - "not-an-ip", - "", - } { - f.Add(seed) - } - - f.Fuzz(func(t *testing.T, raw string) { - parsed := parseIP(raw) - if !parsed.IsValid() { - return - } - - roundTrip := parseIP(parsed.String()) - if !roundTrip.IsValid() { - t.Fatalf("round-trip parse invalid for %q (%q)", raw, parsed.String()) - } - - if normalizeIP(parsed) != normalizeIP(roundTrip) { - t.Fatalf("normalized round-trip mismatch for %q", raw) - } - }) -} - -func FuzzParseRemoteAddr_RoundTripNormalization(f *testing.F) { - for _, seed := range []string{ - "1.1.1.1:443", - "[2606:4700:4700::1]:443", - "1.1.1.1", - "2606:4700:4700::1", - "example.com:443", - "", - } { - f.Add(seed) - } - - f.Fuzz(func(t *testing.T, raw string) { - parsed := parseRemoteAddr(raw) - if !parsed.IsValid() { - return - } - - roundTrip := parseIP(parsed.String()) - if !roundTrip.IsValid() { - t.Fatalf("round-trip parse invalid for remote addr %q (%q)", raw, parsed.String()) - } - - if normalizeIP(parsed) != normalizeIP(roundTrip) { - t.Fatalf("normalized round-trip mismatch for remote addr %q", raw) - } - }) -} - -func FuzzParseXFFValues_ErrorShapeAndOutput(f *testing.F) { - extractor, err := New(WithMaxChainLength(16)) - if err != nil { - f.Fatalf("New() error = %v", err) - } - - for _, seed := range []string{ - "1.1.1.1", - "1.1.1.1, 8.8.8.8", - "1.1.1.1, , 8.8.8.8", - "\t1.1.1.1\t", - ",", - ", ,", - "", - } { - f.Add(seed) - } - - f.Fuzz(func(t *testing.T, raw string) { - valueSets := [][]string{ - {raw}, - {raw, raw}, - {"1.1.1.1", raw}, - {raw, "8.8.8.8"}, - } - - for _, values := range valueSets { - parts, parseErr := extractor.parseXFFValues(values) - - if parseErr != nil { - if !errors.Is(parseErr, ErrChainTooLong) { - t.Fatalf("unexpected parseXFFValues error type for %#v: %v", values, parseErr) - } - continue - } - - if len(parts) > extractor.config.maxChainLength { - t.Fatalf("parts length = %d, max = %d", len(parts), extractor.config.maxChainLength) - } - - for i, part := range parts { - if part == "" { - t.Fatalf("empty part at index %d", i) - } - if part != trimHTTPWhitespace(part) { - t.Fatalf("part has untrimmed HTTP whitespace at index %d: %q", i, part) - } - } - } - }) -} - -func FuzzParseForwardedValues_ErrorShapeAndOutput(f *testing.F) { - extractor, err := New(WithMaxChainLength(16)) - if err != nil { - f.Fatalf("New() error = %v", err) - } - - for _, seed := range []string{ - "for=1.1.1.1", - "for=1.1.1.1, for=8.8.8.8", - "for=1.1.1.1;proto=https", - `for="[2606:4700:4700::1]:443"`, - `for="1.1.1.1\"edge"`, - "for", - `for="unterminated`, - "", - } { - f.Add(seed) - } - - f.Fuzz(func(t *testing.T, raw string) { - valueSets := [][]string{ - {raw}, - {raw, raw}, - {"for=1.1.1.1", raw}, - {raw, "for=8.8.8.8"}, - } - - for _, values := range valueSets { - parts, parseErr := extractor.parseForwardedValues(values) - - if parseErr != nil { - if !errors.Is(parseErr, ErrInvalidForwardedHeader) && !errors.Is(parseErr, ErrChainTooLong) { - t.Fatalf("unexpected parseForwardedValues error type for %#v: %v", values, parseErr) - } - continue - } - - if len(parts) > extractor.config.maxChainLength { - t.Fatalf("parts length = %d, max = %d", len(parts), extractor.config.maxChainLength) - } - - for i, part := range parts { - if part == "" { - t.Fatalf("empty forwarded part at index %d", i) - } - } - } - }) -} diff --git a/presets.go b/presets.go index 60956b8..62b1d75 100644 --- a/presets.go +++ b/presets.go @@ -1,53 +1,35 @@ package clientip -// PresetDirectConnection configures extraction for direct client-to-app +// PresetDirectConnection configures strict extraction for direct client-to-app // traffic. // // This preset extracts from RemoteAddr only. -func PresetDirectConnection() Option { - return WithSourcePriority(builtinSource(sourceRemoteAddr)) +func PresetDirectConnection() Config { + cfg := DefaultConfig() + cfg.Sources = []Source{builtinSource(sourceRemoteAddr)} + return cfg } // PresetLoopbackReverseProxy configures extraction for apps behind a reverse // proxy on the same host (for example NGINX on localhost). // -// It trusts loopback proxy CIDRs and uses X-Forwarded-For with RemoteAddr -// fallback. -func PresetLoopbackReverseProxy() Option { - return func(c *config) error { - return applyOptions(c, - WithTrustedLoopbackProxy(), - WithSourcePriority(builtinSource(sourceXForwardedFor), builtinSource(sourceRemoteAddr)), - ) - } +// It trusts loopback proxy CIDRs and prioritizes X-Forwarded-For before +// RemoteAddr within the extractor's strict source order. +func PresetLoopbackReverseProxy() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{builtinSource(sourceXForwardedFor), builtinSource(sourceRemoteAddr)} + return cfg } // PresetVMReverseProxy configures extraction for apps behind a reverse proxy // in a typical VM or private-network setup. // -// It trusts loopback and private proxy CIDRs and uses X-Forwarded-For with -// RemoteAddr fallback. -func PresetVMReverseProxy() Option { - return func(c *config) error { - return applyOptions(c, - WithTrustedLocalProxyDefaults(), - WithSourcePriority(builtinSource(sourceXForwardedFor), builtinSource(sourceRemoteAddr)), - ) - } -} - -// PresetPreferredHeaderThenXFFLax configures extraction to prefer a single -// custom header, then fall back to X-Forwarded-For and RemoteAddr. -// -// It also enables SecurityModeLax so invalid values in the preferred header -// can fall through to lower-priority sources. -// -// Header-based sources still require trusted proxy CIDRs. -func PresetPreferredHeaderThenXFFLax(header string) Option { - return func(c *config) error { - return applyOptions(c, - WithSourcePriority(HeaderSource(header), builtinSource(sourceXForwardedFor), builtinSource(sourceRemoteAddr)), - WithSecurityMode(SecurityModeLax), - ) - } +// It trusts loopback and private proxy CIDRs and prioritizes X-Forwarded-For +// before RemoteAddr within the extractor's strict source order. +func PresetVMReverseProxy() Config { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LocalProxyPrefixes() + cfg.Sources = []Source{builtinSource(sourceXForwardedFor), builtinSource(sourceRemoteAddr)} + return cfg } diff --git a/presets_test.go b/presets_test.go index 53711d9..d8bd3eb 100644 --- a/presets_test.go +++ b/presets_test.go @@ -1,7 +1,6 @@ package clientip import ( - "net/http" "strings" "testing" @@ -11,13 +10,13 @@ import ( func TestPresets_Config(t *testing.T) { tests := []struct { name string - opts []Option + cfg Config want configSnapshot wantErrText string }{ { name: "direct connection", - opts: []Option{PresetDirectConnection()}, + cfg: PresetDirectConnection(), want: configSnapshot{ TrustedProxyCIDRs: []string{}, MinTrustedProxies: 0, @@ -26,14 +25,13 @@ func TestPresets_Config(t *testing.T) { AllowReservedPrefixes: []string{}, MaxChainLength: DefaultMaxChainLength, ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, DebugMode: false, SourcePriority: []string{SourceRemoteAddr.String()}, }, }, { name: "loopback reverse proxy", - opts: []Option{PresetLoopbackReverseProxy()}, + cfg: PresetLoopbackReverseProxy(), want: configSnapshot{ TrustedProxyCIDRs: []string{"127.0.0.0/8", "::1/128"}, MinTrustedProxies: 0, @@ -42,14 +40,13 @@ func TestPresets_Config(t *testing.T) { AllowReservedPrefixes: []string{}, MaxChainLength: DefaultMaxChainLength, ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, DebugMode: false, SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, }, }, { name: "vm reverse proxy", - opts: []Option{PresetVMReverseProxy()}, + cfg: PresetVMReverseProxy(), want: configSnapshot{ TrustedProxyCIDRs: []string{"127.0.0.0/8", "::1/128", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7"}, MinTrustedProxies: 0, @@ -58,24 +55,15 @@ func TestPresets_Config(t *testing.T) { AllowReservedPrefixes: []string{}, MaxChainLength: DefaultMaxChainLength, ChainSelection: RightmostUntrustedIP, - SecurityMode: SecurityModeStrict, DebugMode: false, SourcePriority: []string{SourceXForwardedFor.String(), SourceRemoteAddr.String()}, }, }, - { - name: "preferred header then xff lax invalid header", - opts: []Option{ - WithTrustedLoopbackProxy(), - PresetPreferredHeaderThenXFFLax(" "), - }, - wantErrText: "source names cannot be empty", - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - extractor, err := New(tt.opts...) + extractor, err := New(tt.cfg) if tt.wantErrText != "" { if err == nil { t.Fatalf("New() error = nil, want containing %q", tt.wantErrText) @@ -96,52 +84,3 @@ func TestPresets_Config(t *testing.T) { }) } } - -func TestPresetPreferredHeaderThenXFFLax_EndToEnd(t *testing.T) { - extractor, err := New( - WithTrustedLoopbackProxy(), - PresetPreferredHeaderThenXFFLax("X-Frontend-IP"), - ) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - req := &http.Request{ - RemoteAddr: "127.0.0.1:8080", - Header: make(http.Header), - } - req.Header.Set("X-Frontend-IP", "not-an-ip") - req.Header.Set("X-Forwarded-For", "8.8.8.8") - - extraction, err := extractor.Extract(req) - if err != nil { - t.Fatalf("Extract() error = %v", err) - } - - want := struct { - IP string - Source Source - TrustedProxyCount int - HasDebugInfo bool - }{ - IP: "8.8.8.8", - Source: SourceXForwardedFor, - TrustedProxyCount: 0, - HasDebugInfo: false, - } - got := struct { - IP string - Source Source - TrustedProxyCount int - HasDebugInfo bool - }{ - IP: extraction.IP.String(), - Source: extraction.Source, - TrustedProxyCount: extraction.TrustedProxyCount, - HasDebugInfo: extraction.DebugInfo != nil, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Fatalf("extraction mismatch (-want +got):\n%s", diff) - } -} diff --git a/request_input.go b/request_input.go deleted file mode 100644 index 69134e4..0000000 --- a/request_input.go +++ /dev/null @@ -1,153 +0,0 @@ -package clientip - -import ( - "context" - "net/http" - "net/url" -) - -// HeaderValues provides access to request header values by name. -// -// Implementations should return one slice entry per received header line. -// Single-IP sources rely on per-line values to detect duplicates, and chain -// sources preserve wire order across repeated lines. -// -// Header names are requested in canonical MIME format (for example -// "X-Forwarded-For"). -// -// net/http's http.Header satisfies this interface directly. -type HeaderValues interface { - Values(name string) []string -} - -// HeaderValuesFunc adapts a function to the HeaderValues interface. -type HeaderValuesFunc func(name string) []string - -// Values implements HeaderValues. -func (f HeaderValuesFunc) Values(name string) []string { - if f == nil { - return nil - } - - return f(name) -} - -// RequestInput provides framework-agnostic request data for extraction. -// -// Context defaults to context.Background() when nil. -// -// For Headers, preserve repeated header lines as separate values for each -// header name (for example two X-Forwarded-For lines should yield a slice with -// length 2, and two X-Real-IP lines should also yield length 2). -type RequestInput struct { - Context context.Context - RemoteAddr string - Path string - Headers HeaderValues -} - -func requestInputContext(input RequestInput) context.Context { - if input.Context == nil { - return context.Background() - } - - return input.Context -} - -func requestFromInput(input RequestInput, sourceHeaderKeys []string) *http.Request { - req := &http.Request{RemoteAddr: input.RemoteAddr} - if input.Path != "" { - req.URL = &url.URL{Path: input.Path} - } - - if input.Headers == nil { - return req - } - - if h, ok := input.Headers.(http.Header); ok { - req.Header = h - return req - } - - if h, ok := input.Headers.(*http.Header); ok && h != nil { - req.Header = *h - return req - } - - if h, ok := input.Headers.(HeaderValuesFunc); ok { - if h == nil { - return req - } - - input.Headers = h - } else if isNilValue(input.Headers) { - return req - } - - if len(sourceHeaderKeys) == 0 { - return req - } - if len(sourceHeaderKeys) == 1 { - key := sourceHeaderKeys[0] - values := input.Headers.Values(key) - if len(values) > 0 { - req.Header = http.Header{key: values} - } - - return req - } - - var headers http.Header - for _, key := range sourceHeaderKeys { - values := input.Headers.Values(key) - if len(values) == 0 { - continue - } - if headers == nil { - headers = make(http.Header, len(sourceHeaderKeys)) - } - - headers[key] = values - } - - if headers != nil { - req.Header = headers - } - - return req -} - -func sourceHeaderKeys(sourcePriority []Source) []string { - keys := make([]string, 0, len(sourcePriority)) - seen := make(map[string]struct{}, len(sourcePriority)) - - for _, source := range sourcePriority { - key, ok := sourceHeaderKey(source) - if !ok { - continue - } - - if _, duplicate := seen[key]; duplicate { - continue - } - - seen[key] = struct{}{} - keys = append(keys, key) - } - - return keys -} - -func sourceHeaderKey(source Source) (string, bool) { - source = canonicalSource(source) - if !source.valid() { - return "", false - } - - key, ok := source.headerKey() - if !ok { - return "", false - } - - return key, true -} diff --git a/resolver.go b/resolver.go new file mode 100644 index 0000000..7559607 --- /dev/null +++ b/resolver.go @@ -0,0 +1,315 @@ +package clientip + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/netip" + "sync" +) + +var errNilResolverExtractor = errors.New("resolver extractor cannot be nil") + +type resolverStateContextKey struct{} + +type resolutionSlot struct { + set bool + value Resolution +} + +// resolverState caches strict and preferred resolutions for one request. +// The mutex is held during compute() to guarantee at-most-once extraction; +// concurrent callers on the same request block until the first completes. +type resolverState struct { + mu sync.Mutex + strict resolutionSlot + preferred resolutionSlot +} + +// PreferredFallback controls which explicit fallback ResolvePreferred applies +// after strict extraction fails. +type PreferredFallback uint8 + +const ( + // PreferredFallbackNone leaves ResolvePreferred without a fallback path. + PreferredFallbackNone PreferredFallback = iota + // PreferredFallbackRemoteAddr falls back to parsed RemoteAddr. + PreferredFallbackRemoteAddr + // PreferredFallbackStaticIP falls back to StaticFallbackIP. + PreferredFallbackStaticIP +) + +func (f PreferredFallback) valid() bool { + return f == PreferredFallbackNone || f == PreferredFallbackRemoteAddr || f == PreferredFallbackStaticIP +} + +// ResolverConfig configures Resolver preferred fallback behavior. +type ResolverConfig struct { + // PreferredFallback selects which explicit fallback ResolvePreferred applies + // after strict extraction fails. + PreferredFallback PreferredFallback + // StaticFallbackIP is required when PreferredFallback is + // PreferredFallbackStaticIP. + StaticFallbackIP netip.Addr +} + +// Resolution captures a resolver result, including fallback metadata. +type Resolution struct { + Extraction + Err error + FallbackUsed bool +} + +// OK reports whether the resolution produced a usable IP without error. +func (r Resolution) OK() bool { + return r.Err == nil && r.IP.IsValid() +} + +// Resolver orchestrates strict and preferred resolution on top of Extractor. +type Resolver struct { + extractor *Extractor + config ResolverConfig +} + +// NewResolver creates a Resolver for a reusable Extractor. +func NewResolver(extractor *Extractor, config ResolverConfig) (*Resolver, error) { + if extractor == nil { + return nil, fmt.Errorf("invalid resolver configuration: %w", errNilResolverExtractor) + } + + config.StaticFallbackIP = normalizeIP(config.StaticFallbackIP) + if !config.PreferredFallback.valid() { + return nil, fmt.Errorf("invalid resolver configuration: unsupported preferred fallback %d", config.PreferredFallback) + } + + switch config.PreferredFallback { + case PreferredFallbackNone, PreferredFallbackRemoteAddr: + if config.StaticFallbackIP.IsValid() { + return nil, fmt.Errorf("invalid resolver configuration: StaticFallbackIP requires PreferredFallbackStaticIP") + } + case PreferredFallbackStaticIP: + if !config.StaticFallbackIP.IsValid() { + return nil, fmt.Errorf("invalid resolver configuration: PreferredFallbackStaticIP requires StaticFallbackIP") + } + } + + return &Resolver{extractor: extractor, config: config}, nil +} + +// ResolveStrict resolves client IP information without fallback. +func (r *Resolver) ResolveStrict(req *http.Request) (*http.Request, Resolution) { + if r == nil || r.extractor == nil { + return req, Resolution{Err: errNilResolverExtractor} + } + if req == nil { + return nil, Resolution{Err: ErrNilRequest} + } + + req, state := requestWithResolverState(req) + resolution := state.ResolveStrict(func() Resolution { return r.resolveStrictRequest(req) }) + return req, resolution +} + +func (s *resolverState) ResolveStrict(compute func() Resolution) Resolution { + s.mu.Lock() + defer s.mu.Unlock() + + if s.strict.set { + return s.strict.value + } + + value := compute() + s.strict = resolutionSlot{set: true, value: value} + return value +} + +func (s *resolverState) ResolvePreferred( + computeStrict func() Resolution, + shouldFallback func(Resolution) bool, + fallback func(Resolution) (Resolution, bool), +) Resolution { + s.mu.Lock() + defer s.mu.Unlock() + + if s.preferred.set { + return s.preferred.value + } + + strict := s.strict.value + if !s.strict.set { + strict = computeStrict() + s.strict = resolutionSlot{set: true, value: strict} + } + + preferred := strict + if shouldFallback != nil && shouldFallback(strict) && fallback != nil { + if resolved, ok := fallback(strict); ok { + preferred = resolved + } + } + + s.preferred = resolutionSlot{set: true, value: preferred} + return preferred +} + +func (s *resolverState) StrictValue() (Resolution, bool) { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.strict.set { + return Resolution{}, false + } + + return s.strict.value, true +} + +func (s *resolverState) PreferredValue() (Resolution, bool) { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.preferred.set { + return Resolution{}, false + } + + return s.preferred.value, true +} + +// ResolvePreferred resolves client IP information using the configured +// preferred fallback policy after strict extraction fails. +func (r *Resolver) ResolvePreferred(req *http.Request) (*http.Request, Resolution) { + if r == nil || r.extractor == nil { + return req, Resolution{Err: errNilResolverExtractor} + } + if req == nil { + return nil, Resolution{Err: ErrNilRequest} + } + + req, state := requestWithResolverState(req) + preferred := state.ResolvePreferred( + func() Resolution { return r.resolveStrictRequest(req) }, + func(strict Resolution) bool { + return strict.Err != nil && !isResolverTerminalContextError(strict.Err) + }, + func(Resolution) (Resolution, bool) { return r.preferredFallback(req.RemoteAddr) }, + ) + return req, preferred +} + +// ResolveInputStrict resolves client IP information from framework-agnostic input without fallback. +func (r *Resolver) ResolveInputStrict(input Input) (Input, Resolution) { + if r == nil || r.extractor == nil { + return input, Resolution{Err: errNilResolverExtractor} + } + + input, state := inputWithResolverState(input) + resolution := state.ResolveStrict(func() Resolution { return r.resolveStrictInput(input) }) + return input, resolution +} + +// ResolveInputPreferred resolves client IP information from framework-agnostic +// input using the configured preferred fallback policy after strict extraction +// fails. +func (r *Resolver) ResolveInputPreferred(input Input) (Input, Resolution) { + if r == nil || r.extractor == nil { + return input, Resolution{Err: errNilResolverExtractor} + } + + input, state := inputWithResolverState(input) + preferred := state.ResolvePreferred( + func() Resolution { return r.resolveStrictInput(input) }, + func(strict Resolution) bool { + return strict.Err != nil && !isResolverTerminalContextError(strict.Err) + }, + func(Resolution) (Resolution, bool) { return r.preferredFallback(input.RemoteAddr) }, + ) + return input, preferred +} + +// StrictResolutionFromContext returns the cached strict resolution, if present. +func StrictResolutionFromContext(ctx context.Context) (Resolution, bool) { + state, ok := resolverStateFromContext(ctx) + if !ok { + return Resolution{}, false + } + + return state.StrictValue() +} + +// PreferredResolutionFromContext returns the cached preferred resolution, if present. +func PreferredResolutionFromContext(ctx context.Context) (Resolution, bool) { + state, ok := resolverStateFromContext(ctx) + if !ok { + return Resolution{}, false + } + + return state.PreferredValue() +} + +func (r *Resolver) resolveStrictRequest(req *http.Request) Resolution { + extraction, err := r.extractor.Extract(req) + return Resolution{Extraction: extraction, Err: err} +} + +func (r *Resolver) resolveStrictInput(input Input) Resolution { + extraction, err := r.extractor.ExtractInput(input) + return Resolution{Extraction: extraction, Err: err} +} + +func (r *Resolver) preferredFallback(remoteAddr string) (Resolution, bool) { + switch r.config.PreferredFallback { + case PreferredFallbackRemoteAddr: + ip, err := ParseRemoteAddr(remoteAddr) + if err == nil { + return Resolution{ + Extraction: Extraction{IP: ip, Source: SourceRemoteAddr}, + FallbackUsed: true, + }, true + } + case PreferredFallbackStaticIP: + return Resolution{ + Extraction: Extraction{IP: r.config.StaticFallbackIP, Source: SourceStaticFallback}, + FallbackUsed: true, + }, true + } + + return Resolution{}, false +} + +func resolverStateFromContext(ctx context.Context) (*resolverState, bool) { + if ctx == nil { + return nil, false + } + + state, ok := ctx.Value(resolverStateContextKey{}).(*resolverState) + if !ok || state == nil { + return nil, false + } + + return state, true +} + +func requestWithResolverState(req *http.Request) (*http.Request, *resolverState) { + if state, ok := resolverStateFromContext(req.Context()); ok { + return req, state + } + + state := &resolverState{} + return req.WithContext(context.WithValue(req.Context(), resolverStateContextKey{}, state)), state +} + +func inputWithResolverState(input Input) (Input, *resolverState) { + ctx := requestInputContext(input) + if state, ok := resolverStateFromContext(ctx); ok { + input.Context = ctx + return input, state + } + + state := &resolverState{} + input.Context = context.WithValue(ctx, resolverStateContextKey{}, state) + return input, state +} + +func isResolverTerminalContextError(err error) bool { + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) +} diff --git a/resolver_test.go b/resolver_test.go new file mode 100644 index 0000000..77389d8 --- /dev/null +++ b/resolver_test.go @@ -0,0 +1,590 @@ +package clientip + +import ( + "context" + "errors" + "net/http" + "net/netip" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +type countingResolverSource struct { + calls int + result Extraction + err error + extractFn func(requestView) (Extraction, error) + sourceName Source +} + +func (s *countingResolverSource) extract(req requestView) (Extraction, error) { + s.calls++ + if s.extractFn != nil { + return s.extractFn(req) + } + + return s.result, s.err +} + +func (s *countingResolverSource) name() string { + return "counting" +} + +func (s *countingResolverSource) sourceInfo() Source { + if s.sourceName.valid() { + return s.sourceName + } + if s.result.Source.valid() { + return s.result.Source + } + + return SourceRemoteAddr +} + +func newResolverTestExtractor(source sourceExtractor) *Extractor { + return &Extractor{ + config: &config{ + sourcePriority: []Source{HeaderSource("X-Test-IP")}, + sourceHeaderKeys: []string{"X-Test-IP"}, + }, + source: source, + } +} + +func mustNewResolver(t *testing.T, extractor *Extractor, config ResolverConfig) *Resolver { + t.Helper() + + resolver, err := NewResolver(extractor, config) + if err != nil { + t.Fatalf("NewResolver() error = %v", err) + } + + return resolver +} + +func TestNewResolver_InvalidConfig(t *testing.T) { + extractor := newResolverTestExtractor(&countingResolverSource{}) + + tests := []struct { + name string + config ResolverConfig + wantErrText string + }{ + { + name: "unsupported preferred fallback", + config: ResolverConfig{PreferredFallback: PreferredFallback(99)}, + wantErrText: "unsupported preferred fallback", + }, + { + name: "static fallback requires static IP", + config: ResolverConfig{PreferredFallback: PreferredFallbackStaticIP}, + wantErrText: "PreferredFallbackStaticIP requires StaticFallbackIP", + }, + { + name: "remote addr fallback rejects static IP", + config: ResolverConfig{ + PreferredFallback: PreferredFallbackRemoteAddr, + StaticFallbackIP: netip.MustParseAddr("0.0.0.0"), + }, + wantErrText: "StaticFallbackIP requires PreferredFallbackStaticIP", + }, + { + name: "no fallback rejects static IP", + config: ResolverConfig{ + StaticFallbackIP: netip.MustParseAddr("0.0.0.0"), + }, + wantErrText: "StaticFallbackIP requires PreferredFallbackStaticIP", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewResolver(extractor, tt.config) + if err == nil { + t.Fatalf("NewResolver() error = nil, want containing %q", tt.wantErrText) + } + if !strings.Contains(err.Error(), tt.wantErrText) { + t.Fatalf("NewResolver() error = %q, want containing %q", err.Error(), tt.wantErrText) + } + }) + } +} + +func TestResolver_ResolveStrict_CachesSuccess(t *testing.T) { + source := &countingResolverSource{ + result: Extraction{IP: netip.MustParseAddr("8.8.8.8"), Source: SourceXRealIP}, + } + resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{}) + + req := &http.Request{RemoteAddr: "203.0.113.10:443", Header: make(http.Header)} + req, first := resolver.ResolveStrict(req) + req, second := resolver.ResolveStrict(req) + + if source.calls != 1 { + t.Fatalf("extract calls = %d, want 1", source.calls) + } + if first.Err != nil { + t.Fatalf("first error = %v, want nil", first.Err) + } + if second.Err != nil { + t.Fatalf("second error = %v, want nil", second.Err) + } + if got, want := second.IP, netip.MustParseAddr("8.8.8.8"); got != want { + t.Fatalf("IP = %s, want %s", got, want) + } + if got, want := second.Source, SourceXRealIP; got != want { + t.Fatalf("Source = %q, want %q", got, want) + } + if second.FallbackUsed { + t.Fatal("FallbackUsed = true, want false") + } + + cached, ok := StrictResolutionFromContext(req.Context()) + if !ok { + t.Fatal("StrictResolutionFromContext() found no cached resolution") + } + if cached != second { + t.Fatalf("cached resolution = %#v, want %#v", cached, second) + } + if _, ok := PreferredResolutionFromContext(req.Context()); ok { + t.Fatal("PreferredResolutionFromContext() = true, want false") + } +} + +func TestResolver_ResolveStrict_CachesFailure(t *testing.T) { + strictErr := &ExtractionError{Err: ErrInvalidIP, Source: SourceXRealIP} + source := &countingResolverSource{err: strictErr} + resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{}) + + req := &http.Request{RemoteAddr: "203.0.113.10:443", Header: make(http.Header)} + req, first := resolver.ResolveStrict(req) + req, second := resolver.ResolveStrict(req) + + if source.calls != 1 { + t.Fatalf("extract calls = %d, want 1", source.calls) + } + if !errors.Is(first.Err, ErrInvalidIP) { + t.Fatalf("first error = %v, want ErrInvalidIP", first.Err) + } + if !errors.Is(second.Err, ErrInvalidIP) { + t.Fatalf("second error = %v, want ErrInvalidIP", second.Err) + } + if second.OK() { + t.Fatal("OK() = true, want false") + } + if got, want := second.Source, SourceXRealIP; got != want { + t.Fatalf("Source = %q, want %q", got, want) + } + + cached, ok := StrictResolutionFromContext(req.Context()) + if !ok { + t.Fatal("StrictResolutionFromContext() found no cached resolution") + } + if !errors.Is(cached.Err, ErrInvalidIP) { + t.Fatalf("cached error = %v, want ErrInvalidIP", cached.Err) + } +} + +func TestResolver_ResolvePreferred_ReusesStrictCachedResult(t *testing.T) { + source := &countingResolverSource{ + result: Extraction{IP: netip.MustParseAddr("8.8.8.8"), Source: SourceXRealIP}, + } + resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{PreferredFallback: PreferredFallbackRemoteAddr}) + + req := &http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)} + req, strict := resolver.ResolveStrict(req) + req, preferred := resolver.ResolvePreferred(req) + + if source.calls != 1 { + t.Fatalf("extract calls = %d, want 1", source.calls) + } + if strict != preferred { + t.Fatalf("preferred resolution = %#v, want %#v", preferred, strict) + } + if preferred.FallbackUsed { + t.Fatal("FallbackUsed = true, want false") + } + + cached, ok := PreferredResolutionFromContext(req.Context()) + if !ok { + t.Fatal("PreferredResolutionFromContext() found no cached resolution") + } + if cached != preferred { + t.Fatalf("cached preferred resolution = %#v, want %#v", cached, preferred) + } +} + +func TestResolver_ResolvePreferred_ParseRemoteAddrFallback(t *testing.T) { + strictErr := &ExtractionError{Err: ErrInvalidIP, Source: SourceXRealIP} + source := &countingResolverSource{err: strictErr} + resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{PreferredFallback: PreferredFallbackRemoteAddr}) + + req := &http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)} + req, resolution := resolver.ResolvePreferred(req) + + if source.calls != 1 { + t.Fatalf("extract calls = %d, want 1", source.calls) + } + if !resolution.OK() { + t.Fatalf("ResolvePreferred() error = %v", resolution.Err) + } + if !resolution.FallbackUsed { + t.Fatal("FallbackUsed = false, want true") + } + if got, want := resolution.IP, netip.MustParseAddr("127.0.0.1"); got != want { + t.Fatalf("IP = %s, want %s", got, want) + } + if got, want := resolution.Source, SourceRemoteAddr; got != want { + t.Fatalf("Source = %q, want %q", got, want) + } + + strict, ok := StrictResolutionFromContext(req.Context()) + if !ok { + t.Fatal("StrictResolutionFromContext() found no cached strict resolution") + } + if !errors.Is(strict.Err, ErrInvalidIP) { + t.Fatalf("strict error = %v, want ErrInvalidIP", strict.Err) + } + + preferred, ok := PreferredResolutionFromContext(req.Context()) + if !ok { + t.Fatal("PreferredResolutionFromContext() found no cached preferred resolution") + } + if preferred != resolution { + t.Fatalf("cached preferred resolution = %#v, want %#v", preferred, resolution) + } +} + +func TestResolver_ResolvePreferred_StaticFallback(t *testing.T) { + strictErr := &ExtractionError{Err: ErrInvalidIP, Source: SourceXRealIP} + source := &countingResolverSource{err: strictErr} + resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{ + PreferredFallback: PreferredFallbackStaticIP, + StaticFallbackIP: netip.MustParseAddr("0.0.0.0"), + }) + + req := &http.Request{RemoteAddr: "bad-remote-addr", Header: make(http.Header)} + req, resolution := resolver.ResolvePreferred(req) + + if source.calls != 1 { + t.Fatalf("extract calls = %d, want 1", source.calls) + } + if !resolution.OK() { + t.Fatalf("ResolvePreferred() error = %v", resolution.Err) + } + if !resolution.FallbackUsed { + t.Fatal("FallbackUsed = false, want true") + } + if got, want := resolution.IP, netip.MustParseAddr("0.0.0.0"); got != want { + t.Fatalf("IP = %s, want %s", got, want) + } + if got, want := resolution.Source, SourceStaticFallback; got != want { + t.Fatalf("Source = %q, want %q", got, want) + } + + strict, ok := StrictResolutionFromContext(req.Context()) + if !ok { + t.Fatal("StrictResolutionFromContext() found no cached strict resolution") + } + if !errors.Is(strict.Err, ErrInvalidIP) { + t.Fatalf("strict error = %v, want ErrInvalidIP", strict.Err) + } +} + +func TestResolver_ResolvePreferred_DoesNotFallbackOnCanceledOrDeadline(t *testing.T) { + resolver := mustNewResolver(t, newResolverTestExtractor(&countingResolverSource{ + extractFn: func(req requestView) (Extraction, error) { + return Extraction{}, req.context().Err() + }, + }), ResolverConfig{PreferredFallback: PreferredFallbackRemoteAddr}) + + tests := []struct { + name string + newRequest func() *http.Request + newInput func() Input + wantErr error + }{ + { + name: "request canceled", + newRequest: func() *http.Request { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return (&http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)}).WithContext(ctx) + }, + wantErr: context.Canceled, + }, + { + name: "input deadline exceeded", + newInput: func() Input { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + return Input{Context: ctx, RemoteAddr: "127.0.0.1:8080"} + }, + wantErr: context.DeadlineExceeded, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.newRequest != nil { + req, resolution := resolver.ResolvePreferred(tt.newRequest()) + if !errors.Is(resolution.Err, tt.wantErr) { + t.Fatalf("error = %v, want %v", resolution.Err, tt.wantErr) + } + if resolution.FallbackUsed { + t.Fatal("FallbackUsed = true, want false") + } + if preferred, ok := PreferredResolutionFromContext(req.Context()); !ok || !errors.Is(preferred.Err, tt.wantErr) { + t.Fatalf("cached preferred = %#v, ok=%t, want error %v", preferred, ok, tt.wantErr) + } + return + } + + input, resolution := resolver.ResolveInputPreferred(tt.newInput()) + if !errors.Is(resolution.Err, tt.wantErr) { + t.Fatalf("error = %v, want %v", resolution.Err, tt.wantErr) + } + if resolution.FallbackUsed { + t.Fatal("FallbackUsed = true, want false") + } + if preferred, ok := PreferredResolutionFromContext(input.Context); !ok || !errors.Is(preferred.Err, tt.wantErr) { + t.Fatalf("cached preferred = %#v, ok=%t, want error %v", preferred, ok, tt.wantErr) + } + }) + } +} + +func TestResolver_ResolveInputStrict_CachesSuccess(t *testing.T) { + source := &countingResolverSource{ + result: Extraction{IP: netip.MustParseAddr("2001:db8::1"), Source: SourceXRealIP}, + } + resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{}) + + input := Input{Context: context.Background(), RemoteAddr: "203.0.113.10:443"} + input, first := resolver.ResolveInputStrict(input) + input, second := resolver.ResolveInputStrict(input) + + if source.calls != 1 { + t.Fatalf("extract calls = %d, want 1", source.calls) + } + if first != second { + t.Fatalf("second resolution = %#v, want %#v", second, first) + } + if cached, ok := StrictResolutionFromContext(input.Context); !ok || cached != second { + t.Fatalf("cached strict = %#v, ok=%t, want %#v", cached, ok, second) + } +} + +type resolutionView struct { + IP string + Source Source + ErrInvalidIP bool + FallbackUsed bool +} + +func viewResolution(r Resolution) resolutionView { + view := resolutionView{Source: r.Source, FallbackUsed: r.FallbackUsed} + if r.IP.IsValid() { + view.IP = r.IP.String() + } + view.ErrInvalidIP = errors.Is(r.Err, ErrInvalidIP) + return view +} + +func TestResolverState_ResolveStrict_CachesComputedValue(t *testing.T) { + var state resolverState + computeCalls := 0 + first := state.ResolveStrict(func() Resolution { + computeCalls++ + return Resolution{Extraction: Extraction{Source: SourceXRealIP}} + }) + second := state.ResolveStrict(func() Resolution { + computeCalls++ + return Resolution{Extraction: Extraction{Source: SourceRemoteAddr}} + }) + + if got, want := first.Source, SourceXRealIP; got != want { + t.Fatalf("first strict source = %q, want %q", got, want) + } + if got, want := second.Source, SourceXRealIP; got != want { + t.Fatalf("second strict source = %q, want %q", got, want) + } + if got, want := computeCalls, 1; got != want { + t.Fatalf("strict compute calls = %d, want %d", got, want) + } + + if got, ok := state.StrictValue(); !ok || got.Source != SourceXRealIP { + t.Fatalf("StrictValue() = (%#v, %t), want source %q", got, ok, SourceXRealIP) + } + if _, ok := state.PreferredValue(); ok { + t.Fatal("PreferredValue() ok = true, want false") + } +} + +func TestResolverState_ResolvePreferred(t *testing.T) { + tests := []struct { + name string + strict Resolution + shouldFallback func(Resolution) bool + fallback func(Resolution) (Resolution, bool) + wantPreferred Resolution + wantStrict Resolution + wantStrictCalls int + wantFallbackCalls int + }{ + { + name: "reuses strict result when fallback does not apply", + strict: Resolution{Extraction: Extraction{Source: SourceXRealIP}}, + shouldFallback: func(r Resolution) bool { return r.Err != nil }, + fallback: func(Resolution) (Resolution, bool) { + return Resolution{Extraction: Extraction{Source: SourceRemoteAddr}, FallbackUsed: true}, true + }, + wantPreferred: Resolution{Extraction: Extraction{Source: SourceXRealIP}}, + wantStrict: Resolution{Extraction: Extraction{Source: SourceXRealIP}}, + wantStrictCalls: 1, + wantFallbackCalls: 0, + }, + { + name: "uses fallback when allowed", + strict: Resolution{Err: ErrInvalidIP}, + shouldFallback: func(r Resolution) bool { return r.Err != nil }, + fallback: func(Resolution) (Resolution, bool) { + return Resolution{Extraction: Extraction{Source: SourceRemoteAddr}, FallbackUsed: true}, true + }, + wantPreferred: Resolution{Extraction: Extraction{Source: SourceRemoteAddr}, FallbackUsed: true}, + wantStrict: Resolution{Err: ErrInvalidIP}, + wantStrictCalls: 1, + wantFallbackCalls: 1, + }, + { + name: "keeps strict value when fallback declines", + strict: Resolution{Err: ErrInvalidIP}, + shouldFallback: func(r Resolution) bool { return r.Err != nil }, + fallback: func(Resolution) (Resolution, bool) { + return Resolution{Extraction: Extraction{Source: SourceRemoteAddr}, FallbackUsed: true}, false + }, + wantPreferred: Resolution{Err: ErrInvalidIP}, + wantStrict: Resolution{Err: ErrInvalidIP}, + wantStrictCalls: 1, + wantFallbackCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var state resolverState + strictCalls := 0 + fallbackCalls := 0 + + first := state.ResolvePreferred( + func() Resolution { + strictCalls++ + return tt.strict + }, + tt.shouldFallback, + func(r Resolution) (Resolution, bool) { + fallbackCalls++ + return tt.fallback(r) + }, + ) + second := state.ResolvePreferred( + func() Resolution { + strictCalls++ + return Resolution{Extraction: Extraction{Source: SourceStaticFallback}} + }, + tt.shouldFallback, + func(r Resolution) (Resolution, bool) { + fallbackCalls++ + return tt.fallback(r) + }, + ) + + if diff := cmp.Diff(viewResolution(tt.wantPreferred), viewResolution(first)); diff != "" { + t.Fatalf("first preferred mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(viewResolution(tt.wantPreferred), viewResolution(second)); diff != "" { + t.Fatalf("second preferred mismatch (-want +got):\n%s", diff) + } + if got, want := strictCalls, tt.wantStrictCalls; got != want { + t.Fatalf("strict compute calls = %d, want %d", got, want) + } + if got, want := fallbackCalls, tt.wantFallbackCalls; got != want { + t.Fatalf("fallback calls = %d, want %d", got, want) + } + + if diff := cmp.Diff(viewResolution(tt.wantStrict), viewResolution(mustStrictValue(t, &state))); diff != "" { + t.Fatalf("strict cache mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(viewResolution(tt.wantPreferred), viewResolution(mustPreferredValue(t, &state))); diff != "" { + t.Fatalf("preferred cache mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestResolverState_ResolveStrict_ConcurrentAccessComputesOnce(t *testing.T) { + var state resolverState + var computeCalls atomic.Int32 + + const goroutines = 16 + start := make(chan struct{}) + values := make(chan Source, goroutines) + + var wg sync.WaitGroup + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + values <- state.ResolveStrict(func() Resolution { + computeCalls.Add(1) + time.Sleep(5 * time.Millisecond) + return Resolution{Extraction: Extraction{Source: SourceXRealIP}} + }).Source + }() + } + + close(start) + wg.Wait() + close(values) + + got := make([]Source, 0, goroutines) + for value := range values { + got = append(got, value) + } + + want := make([]Source, goroutines) + for i := range want { + want[i] = SourceXRealIP + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("concurrent strict values mismatch (-want +got):\n%s", diff) + } + if got, want := int(computeCalls.Load()), 1; got != want { + t.Fatalf("strict compute calls = %d, want %d", got, want) + } +} + +func mustStrictValue(t *testing.T, state *resolverState) Resolution { + t.Helper() + value, ok := state.StrictValue() + if !ok { + t.Fatal("StrictValue() ok = false, want true") + } + return value +} + +func mustPreferredValue(t *testing.T, state *resolverState) Resolution { + t.Helper() + value, ok := state.PreferredValue() + if !ok { + t.Fatal("PreferredValue() ok = false, want true") + } + return value +} diff --git a/sources.go b/source.go similarity index 51% rename from sources.go rename to source.go index 6db8c00..6c5abf5 100644 --- a/sources.go +++ b/source.go @@ -1,11 +1,8 @@ package clientip import ( - "context" "encoding/json" "errors" - "net/http" - "net/netip" "net/textproto" "strings" ) @@ -18,16 +15,23 @@ const ( sourceXForwardedFor sourceXRealIP sourceRemoteAddr + sourceStaticFallback sourceHeader ) const ( - builtinSourceNameForwarded = "forwarded" - builtinSourceNameXForwardedFor = "x_forwarded_for" - builtinSourceNameXRealIP = "x_real_ip" - builtinSourceNameRemoteAddr = "remote_addr" + builtinSourceNameForwarded = "forwarded" + builtinSourceNameXForwardedFor = "x_forwarded_for" + builtinSourceNameXRealIP = "x_real_ip" + builtinSourceNameRemoteAddr = "remote_addr" + builtinSourceNameStaticFallback = "static_fallback" ) +// Exported source identifiers for comparison and display. +// +// These are vars because Go does not support const structs. Do not reassign +// them; internal code uses builtinSource() so reassignment would only affect +// caller-side comparisons, not extraction behavior. var ( // SourceForwarded resolves from the RFC7239 Forwarded header. SourceForwarded = Source{kind: sourceForwarded} @@ -37,6 +41,8 @@ var ( SourceXRealIP = Source{kind: sourceXRealIP} // SourceRemoteAddr resolves from Request.RemoteAddr. SourceRemoteAddr = Source{kind: sourceRemoteAddr} + // SourceStaticFallback identifies resolver-only static fallback output. + SourceStaticFallback = Source{kind: sourceStaticFallback} ) // Source identifies one extraction source in priority order. @@ -59,7 +65,7 @@ func HeaderSource(name string) Source { func canonicalSource(source Source) Source { switch source.kind { - case sourceForwarded, sourceXForwardedFor, sourceXRealIP, sourceRemoteAddr: + case sourceForwarded, sourceXForwardedFor, sourceXRealIP, sourceRemoteAddr, sourceStaticFallback: return source case sourceHeader: return sourceFromString(source.headerName) @@ -69,6 +75,12 @@ func canonicalSource(source Source) Source { } func sourceFromString(name string) Source { + // Fast path: check exact matches before trimming/normalizing. + // Internal round-trips always use already-normalized names without whitespace. + if s, ok := sourceFromExact(name); ok { + return s + } + raw := strings.TrimSpace(name) if raw == "" { return Source{} @@ -83,11 +95,30 @@ func sourceFromString(name string) Source { return builtinSource(sourceXRealIP) case builtinSourceNameRemoteAddr: return builtinSource(sourceRemoteAddr) + case builtinSourceNameStaticFallback: + return builtinSource(sourceStaticFallback) default: return Source{kind: sourceHeader, headerName: textproto.CanonicalMIMEHeaderKey(raw)} } } +func sourceFromExact(name string) (Source, bool) { + switch name { + case builtinSourceNameForwarded, "Forwarded": + return builtinSource(sourceForwarded), true + case builtinSourceNameXForwardedFor, "X-Forwarded-For": + return builtinSource(sourceXForwardedFor), true + case builtinSourceNameXRealIP, "X-Real-Ip", "X-Real-IP": + return builtinSource(sourceXRealIP), true + case builtinSourceNameRemoteAddr: + return builtinSource(sourceRemoteAddr), true + case builtinSourceNameStaticFallback: + return builtinSource(sourceStaticFallback), true + default: + return Source{}, false + } +} + // canonicalizeSources ensures every source is in canonical form. // // Sources stored in config.sourcePriority are always canonical; callers must @@ -119,6 +150,8 @@ func (s Source) name() string { return builtinSourceNameXRealIP case sourceRemoteAddr: return builtinSourceNameRemoteAddr + case sourceStaticFallback: + return builtinSourceNameStaticFallback case sourceHeader: return normalizeSourceName(s.headerName) default: @@ -134,7 +167,8 @@ func (s Source) valid() bool { return s.kind == sourceForwarded || s.kind == sourceXForwardedFor || s.kind == sourceXRealIP || - s.kind == sourceRemoteAddr + s.kind == sourceRemoteAddr || + s.kind == sourceStaticFallback } func (s Source) headerKey() (string, bool) { @@ -145,7 +179,7 @@ func (s Source) headerKey() (string, bool) { return "X-Forwarded-For", true case sourceXRealIP: return "X-Real-IP", true - case sourceRemoteAddr, sourceInvalid: + case sourceRemoteAddr, sourceStaticFallback, sourceInvalid: return "", false default: return s.headerName, true @@ -198,156 +232,41 @@ func (s *Source) UnmarshalJSON(data []byte) error { return nil } -type extractionResult struct { - IP netip.Addr - TrustedProxyCount int - DebugInfo *ChainDebugInfo - Source Source -} - -type sourceExtractor interface { - Extract(ctx context.Context, r *http.Request) (extractionResult, error) - - Name() string - Source() Source -} - -func requestPath(r *http.Request) string { - if r == nil || r.URL == nil { - return "" - } - - return r.URL.Path -} - -func (e *Extractor) logSecurityWarning(ctx context.Context, r *http.Request, source Source, event, msg string, attrs ...any) { - remoteAddr := "" - if r != nil { - remoteAddr = r.RemoteAddr - } - - baseAttrs := []any{ - "event", event, - "source", source.String(), - "path", requestPath(r), - "remote_addr", remoteAddr, - } - - baseAttrs = append(baseAttrs, attrs...) - e.config.logger.WarnContext(ctx, msg, baseAttrs...) -} - -func proxyValidationWarningDetails(err error) (event, msg string, ok bool) { - switch { - case errors.Is(err, ErrNoTrustedProxies): - return securityEventNoTrustedProxies, "no trusted proxies found in request chain", true - case errors.Is(err, ErrTooFewTrustedProxies): - return securityEventTooFewTrustedProxies, "trusted proxy count below configured minimum", true - case errors.Is(err, ErrTooManyTrustedProxies): - return securityEventTooManyTrustedProxies, "trusted proxy count exceeds configured maximum", true - default: - return "", "", false - } +func normalizeSourceName(headerName string) string { + return strings.ToLower(strings.ReplaceAll(headerName, "-", "_")) } -func (e *Extractor) logProxyValidationWarning(ctx context.Context, r *http.Request, source Source, err error) { - event, msg, ok := proxyValidationWarningDetails(err) - if !ok { - return - } - - var proxyErr *ProxyValidationError - if errors.As(err, &proxyErr) { - e.logSecurityWarning(ctx, r, source, event, msg, - "trusted_proxy_count", proxyErr.TrustedProxyCount, - "min_trusted_proxies", proxyErr.MinTrustedProxies, - "max_trusted_proxies", proxyErr.MaxTrustedProxies, - ) - return - } - - e.logSecurityWarning(ctx, r, source, event, msg) -} +func sourceHeaderKeys(sourcePriority []Source) []string { + keys := make([]string, 0, len(sourcePriority)) + seen := make(map[string]struct{}, len(sourcePriority)) -func (e *Extractor) extractChainSource( - ctx context.Context, - r *http.Request, - source Source, - headerValues []string, - chainForUntrusted func() string, - untrustedProxyMessage string, - chainTooLongMessage string, - parseValues func([]string) ([]string, error), - handleParseError func(error), -) (extractionResult, error) { - if len(e.config.trustedProxyCIDRs) > 0 { - remoteAddr := "" - if r != nil { - remoteAddr = r.RemoteAddr + for _, source := range sourcePriority { + key, ok := sourceHeaderKey(source) + if !ok { + continue } - remoteIP := parseRemoteAddr(remoteAddr) - if !e.isTrustedProxy(remoteIP) { - chain := "" - if chainForUntrusted != nil { - chain = chainForUntrusted() - } - - e.config.metrics.RecordSecurityEvent(securityEventUntrustedProxy) - e.logSecurityWarning(ctx, r, source, securityEventUntrustedProxy, untrustedProxyMessage) - e.config.metrics.RecordExtractionFailure(source.String()) - return extractionResult{}, &ProxyValidationError{ - ExtractionError: ExtractionError{ - Err: ErrUntrustedProxy, - Source: source, - }, - Chain: chain, - TrustedProxyCount: 0, - MinTrustedProxies: e.config.minTrustedProxies, - MaxTrustedProxies: e.config.maxTrustedProxies, - } - } - } - - parts, err := parseValues(headerValues) - if err != nil { - if errors.Is(err, ErrChainTooLong) { - var chainErr *ChainTooLongError - if errors.As(err, &chainErr) { - e.logSecurityWarning(ctx, r, source, securityEventChainTooLong, chainTooLongMessage, - "chain_length", chainErr.ChainLength, - "max_length", chainErr.MaxLength, - ) - } else { - e.logSecurityWarning(ctx, r, source, securityEventChainTooLong, chainTooLongMessage) - } - } - - if handleParseError != nil { - handleParseError(err) + if _, duplicate := seen[key]; duplicate { + continue } - e.config.metrics.RecordExtractionFailure(source.String()) - return extractionResult{}, err + seen[key] = struct{}{} + keys = append(keys, key) } - ip, trustedCount, debugInfo, err := e.clientIPFromChainWithDebug(source, parts) - if err != nil { - e.logProxyValidationWarning(ctx, r, source, err) - e.config.metrics.RecordExtractionFailure(source.String()) - return extractionResult{}, err - } + return keys +} - e.config.metrics.RecordExtractionSuccess(source.String()) - result := extractionResult{ - IP: ip, - TrustedProxyCount: trustedCount, - Source: source, +func sourceHeaderKey(source Source) (string, bool) { + source = canonicalSource(source) + if !source.valid() { + return "", false } - if e.config.debugMode { - result.DebugInfo = debugInfo + key, ok := source.headerKey() + if !ok { + return "", false } - return result, nil + return key, true } diff --git a/source_build_test.go b/source_build_test.go new file mode 100644 index 0000000..81e7147 --- /dev/null +++ b/source_build_test.go @@ -0,0 +1,69 @@ +package clientip + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestCompileSpecFromSource(t *testing.T) { + tests := []struct { + name string + source Source + want sourceSpec + }{ + { + name: "forwarded source", + source: SourceForwarded, + want: sourceSpec{ + kind: sourceExtractorKindForwarded, + source: SourceForwarded, + headerName: "Forwarded", + }, + }, + { + name: "x forwarded for source", + source: SourceXForwardedFor, + want: sourceSpec{ + kind: sourceExtractorKindXForwardedFor, + source: SourceXForwardedFor, + headerName: "X-Forwarded-For", + }, + }, + { + name: "x real ip source", + source: SourceXRealIP, + want: sourceSpec{ + kind: sourceExtractorKindSingleHeader, + source: SourceXRealIP, + headerName: "X-Real-Ip", + }, + }, + { + name: "remote addr source", + source: SourceRemoteAddr, + want: sourceSpec{ + kind: sourceExtractorKindRemoteAddr, + source: SourceRemoteAddr, + headerName: "", + }, + }, + { + name: "custom header source", + source: HeaderSource("x-custom-header"), + want: sourceSpec{ + kind: sourceExtractorKindSingleHeader, + source: HeaderSource("X-Custom-Header"), + headerName: "X-Custom-Header", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if diff := cmp.Diff(tt.want, compileSpecFromSource(tt.source), cmp.AllowUnexported(sourceSpec{}, Source{})); diff != "" { + t.Fatalf("compileSpecFromSource() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/source_chain.go b/source_chain.go deleted file mode 100644 index 59d9c95..0000000 --- a/source_chain.go +++ /dev/null @@ -1,121 +0,0 @@ -package clientip - -import ( - "context" - "errors" - "net/http" - "net/textproto" - "strings" -) - -type chainedSource struct { - extractor *Extractor - sources []sourceExtractor - name string -} - -func newChainedSource(extractor *Extractor, sources ...sourceExtractor) *chainedSource { - names := make([]string, len(sources)) - for i, s := range sources { - names[i] = s.Name() - } - return &chainedSource{ - extractor: extractor, - sources: sources, - name: "chained[" + strings.Join(names, ",") + "]", - } -} - -func newForwardedForSource(extractor *Extractor) sourceExtractor { - sourceName := builtinSource(sourceXForwardedFor) - return &forwardedForSource{ - extractor: extractor, - sourceName: sourceName, - unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName}, - } -} - -func newForwardedSource(extractor *Extractor) sourceExtractor { - sourceName := builtinSource(sourceForwarded) - return &forwardedSource{ - extractor: extractor, - sourceName: sourceName, - unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName}, - } -} - -func newSingleHeaderSource(extractor *Extractor, headerName string) sourceExtractor { - sourceName := HeaderSource(headerName) - return &singleHeaderSource{ - extractor: extractor, - headerName: headerName, - headerKey: textproto.CanonicalMIMEHeaderKey(headerName), - sourceName: sourceName, - unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName}, - } -} - -func newRemoteAddrSource(extractor *Extractor) sourceExtractor { - sourceName := builtinSource(sourceRemoteAddr) - return &remoteAddrSource{ - extractor: extractor, - sourceName: sourceName, - unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName}, - } -} - -func (c *chainedSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - var lastErr error - for _, source := range c.sources { - // Check if context has been cancelled before attempting next source - if ctx.Err() != nil { - return extractionResult{}, ctx.Err() - } - - result, err := source.Extract(ctx, r) - if err == nil { - if !result.Source.valid() { - result.Source = source.Source() - } - return result, nil - } - - if c.isTerminalError(err) { - return extractionResult{}, err - } - - lastErr = err - } - return extractionResult{}, lastErr -} - -func (c *chainedSource) isTerminalError(err error) bool { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true - } - - if errors.Is(err, ErrSourceUnavailable) { - return false - } - - if c.extractor.config.securityMode == SecurityModeLax { - return false - } - - return errors.Is(err, ErrInvalidIP) || - errors.Is(err, ErrMultipleSingleIPHeaders) || - errors.Is(err, ErrUntrustedProxy) || - errors.Is(err, ErrNoTrustedProxies) || - errors.Is(err, ErrTooFewTrustedProxies) || - errors.Is(err, ErrTooManyTrustedProxies) || - errors.Is(err, ErrChainTooLong) || - errors.Is(err, ErrInvalidForwardedHeader) -} - -func (c *chainedSource) Name() string { - return c.name -} - -func (c *chainedSource) Source() Source { - return Source{} -} diff --git a/source_chain_extract.go b/source_chain_extract.go new file mode 100644 index 0000000..c196569 --- /dev/null +++ b/source_chain_extract.go @@ -0,0 +1,106 @@ +package clientip + +import ( + "net/netip" + "strings" +) + +type chainPolicy struct { + headerName string + parseValues func([]string) ([]string, error) + clientIP clientIPPolicy + trustedProxy proxyPolicy + selection ChainSelection + collectDebugInfo bool + untrustedChainSep string +} + +type chainExtractor struct { + policy chainPolicy +} + +func (e chainExtractor) extract(req requestView, source Source) (Extraction, *extractionFailure, error) { + headerValues := req.valuesCanonical(e.policy.headerName) + if len(headerValues) == 0 { + return Extraction{}, errSourceUnavailable, nil + } + + if len(e.policy.trustedProxy.TrustedProxyCIDRs) > 0 { + remoteIP := parseRemoteAddr(req.remoteAddr()) + if !isTrustedProxy(remoteIP, e.policy.trustedProxy.TrustedProxyMatch, e.policy.trustedProxy.TrustedProxyCIDRs) { + return Extraction{}, &extractionFailure{ + kind: failureUntrustedProxy, + source: source, + chain: strings.Join(headerValues, e.chainSeparator()), + trustedProxyCount: 0, + minTrustedProxies: e.policy.trustedProxy.MinTrustedProxies, + maxTrustedProxies: e.policy.trustedProxy.MaxTrustedProxies, + }, nil + } + } + + parts, err := e.policy.parseValues(headerValues) + if err != nil { + return Extraction{}, nil, err + } + if len(parts) == 0 { + return Extraction{}, &extractionFailure{kind: failureEmptyChain, source: source}, nil + } + + analysis, clientIP, err := e.analyzeChain(parts) + if err != nil { + return Extraction{}, &extractionFailure{ + kind: failureProxyValidation, + source: source, + chain: strings.Join(parts, ", "), + trustedProxyCount: analysis.TrustedCount, + minTrustedProxies: e.policy.trustedProxy.MinTrustedProxies, + maxTrustedProxies: e.policy.trustedProxy.MaxTrustedProxies, + }, nil + } + + clientIPStr := parts[analysis.ClientIndex] + disposition := evaluateClientIP(clientIP, e.policy.clientIP) + if disposition != clientIPValid { + return Extraction{}, &extractionFailure{ + kind: failureInvalidClientIP, + source: source, + chain: strings.Join(parts, ", "), + index: analysis.ClientIndex, + extractedIP: clientIPStr, + trustedProxyCount: analysis.TrustedCount, + clientIPDisposition: disposition, + }, nil + } + + result := Extraction{ + IP: normalizeIP(clientIP), + TrustedProxyCount: analysis.TrustedCount, + Source: source, + } + if e.policy.collectDebugInfo { + result.DebugInfo = &ChainDebugInfo{ + FullChain: parts, + ClientIndex: analysis.ClientIndex, + TrustedIndices: analysis.TrustedIndices, + } + } + + return result, nil, nil +} + +func (e chainExtractor) analyzeChain(parts []string) (chainAnalysis, netip.Addr, error) { + if e.policy.selection == LeftmostUntrustedIP { + return analyzeChainLeftmost(parts, e.policy.trustedProxy, e.policy.collectDebugInfo) + } + + return analyzeChainRightmost(parts, e.policy.trustedProxy, e.policy.collectDebugInfo) +} + +func (e chainExtractor) chainSeparator() string { + if e.policy.untrustedChainSep != "" { + return e.policy.untrustedChainSep + } + + return ", " +} diff --git a/source_chain_extract_test.go b/source_chain_extract_test.go new file mode 100644 index 0000000..327d2af --- /dev/null +++ b/source_chain_extract_test.go @@ -0,0 +1,401 @@ +package clientip + +import ( + "errors" + "net/netip" + "strings" + "testing" +) + +// simpleXFFParse is a minimal comma-split parser for tests. +func simpleXFFParse(values []string) ([]string, error) { + var parts []string + for _, v := range values { + for _, seg := range strings.Split(v, ",") { + seg = strings.TrimSpace(seg) + if seg != "" { + parts = append(parts, seg) + } + } + } + return parts, nil +} + +func TestChainExtractor_HeaderMissing(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + headerMap: map[string][]string{}, + } + + _, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure != errSourceUnavailable { + t.Errorf("failure = %+v, want errSourceUnavailable", failure) + } +} + +func TestChainExtractor_SingleValidValue(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + headerMap: map[string][]string{ + "X-Forwarded-For": {"8.8.8.8"}, + }, + } + + result, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("8.8.8.8") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } + if result.Source != SourceXForwardedFor { + t.Errorf("Source = %v, want %v", result.Source, SourceXForwardedFor) + } +} + +func TestChainExtractor_ChainWithTrustedProxies(t *testing.T) { + trustedCIDR := netip.MustParsePrefix("10.0.0.0/8") + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + trustedProxy: proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, + TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), + }, + selection: RightmostUntrustedIP, + }} + + // Chain: client, proxy1, proxy2 + // 10.0.0.x are trusted, so client IP should be 8.8.8.8 + req := requestView{ + remoteAddrValue: "10.0.0.3:8080", + headerMap: map[string][]string{ + "X-Forwarded-For": {"8.8.8.8, 10.0.0.1, 10.0.0.2"}, + }, + } + + result, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("8.8.8.8") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } + if result.TrustedProxyCount != 2 { + t.Errorf("TrustedProxyCount = %d, want 2", result.TrustedProxyCount) + } +} + +func TestChainExtractor_EmptyChainAfterParse(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: func(values []string) ([]string, error) { + return nil, nil // empty result + }, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + headerMap: map[string][]string{ + "X-Forwarded-For": {" "}, + }, + } + + _, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure.kind != failureEmptyChain { + t.Errorf("failure.kind = %v, want failureEmptyChain", failure.kind) + } +} + +func TestChainExtractor_UntrustedProxy(t *testing.T) { + trustedCIDR := netip.MustParsePrefix("10.0.0.0/8") + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + trustedProxy: proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, + TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), + }, + selection: RightmostUntrustedIP, + }} + + // Remote addr is NOT trusted. + req := requestView{ + remoteAddrValue: "5.5.5.5:4567", + headerMap: map[string][]string{ + "X-Forwarded-For": {"9.9.9.9"}, + }, + } + + _, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure.kind != failureUntrustedProxy { + t.Errorf("failure.kind = %v, want failureUntrustedProxy", failure.kind) + } +} + +func TestChainExtractor_DebugInfoCollected(t *testing.T) { + trustedCIDR := netip.MustParsePrefix("10.0.0.0/8") + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + trustedProxy: proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, + TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), + }, + selection: RightmostUntrustedIP, + collectDebugInfo: true, + }} + + req := requestView{ + remoteAddrValue: "10.0.0.3:8080", + headerMap: map[string][]string{ + "X-Forwarded-For": {"8.8.8.8, 10.0.0.1, 10.0.0.2"}, + }, + } + + result, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + if result.DebugInfo == nil { + t.Fatal("expected DebugInfo to be set") + } + if len(result.DebugInfo.FullChain) != 3 { + t.Errorf("FullChain length = %d, want 3", len(result.DebugInfo.FullChain)) + } + if result.DebugInfo.ClientIndex != 0 { + t.Errorf("ClientIndex = %d, want 0", result.DebugInfo.ClientIndex) + } + if len(result.DebugInfo.TrustedIndices) != 2 { + t.Errorf("TrustedIndices length = %d, want 2", len(result.DebugInfo.TrustedIndices)) + } +} + +func TestChainExtractor_DebugInfoNotCollectedByDefault(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + selection: RightmostUntrustedIP, + collectDebugInfo: false, + }} + + req := requestView{ + headerMap: map[string][]string{ + "X-Forwarded-For": {"8.8.8.8"}, + }, + } + + result, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + if result.DebugInfo != nil { + t.Error("expected DebugInfo to be nil when collectDebugInfo is false") + } +} + +func TestChainExtractor_InvalidClientIP(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + headerMap: map[string][]string{ + "X-Forwarded-For": {"not-valid-ip"}, + }, + } + + _, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } + if failure.extractedIP != "not-valid-ip" { + t.Errorf("failure.extractedIP = %q, want %q", failure.extractedIP, "not-valid-ip") + } +} + +func TestChainExtractor_ParseValuesError(t *testing.T) { + parseErr := &chainTooLongParseError{ChainLength: 100, MaxLength: 50} + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: func(values []string) ([]string, error) { + return nil, parseErr + }, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + headerMap: map[string][]string{ + "X-Forwarded-For": {"a, b, c"}, + }, + } + + _, failure, err := ext.extract(req, SourceXForwardedFor) + if err == nil { + t.Fatal("expected error from parseValues, got nil") + } + if failure != nil { + t.Errorf("failure should be nil when parseValues returns error, got %+v", failure) + } + if !errors.Is(err, parseErr) { + t.Errorf("error = %v, want %v", err, parseErr) + } +} + +func TestChainExtractor_MultipleHeaderValues(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + selection: RightmostUntrustedIP, + }} + + // Multiple X-Forwarded-For header values (as separate entries). + // simpleXFFParse treats them as two separate chain parts. + req := requestView{ + headerMap: map[string][]string{ + "X-Forwarded-For": {"8.8.8.8", "9.9.9.9"}, + }, + } + + result, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + // With no trusted CIDRs and no MaxTrustedProxies, every entry is walked + // as trusted. The leftmost entry (index 0) becomes the client IP. + wantIP := netip.MustParseAddr("8.8.8.8") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} + +func TestChainExtractor_LoopbackClientIPRejected(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + headerMap: map[string][]string{ + "X-Forwarded-For": {"127.0.0.1"}, + }, + } + + _, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure == nil { + t.Fatal("expected failure for loopback client IP") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } +} + +func TestChainExtractor_NoHeadersAtAll(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + selection: RightmostUntrustedIP, + }} + + req := requestView{} + + _, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure != errSourceUnavailable { + t.Errorf("failure = %+v, want errSourceUnavailable", failure) + } +} + +func TestChainExtractor_IPv6InChain(t *testing.T) { + trustedCIDR := netip.MustParsePrefix("fd00::/8") + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + trustedProxy: proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, + TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), + }, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + remoteAddrValue: "[fd00::3]:8080", + headerMap: map[string][]string{ + "X-Forwarded-For": {"2606:4700::1, fd00::1, fd00::2"}, + }, + } + + result, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("2606:4700::1") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} diff --git a/source_chain_test.go b/source_chain_test.go deleted file mode 100644 index f4b1682..0000000 --- a/source_chain_test.go +++ /dev/null @@ -1,431 +0,0 @@ -package clientip - -import ( - "context" - "errors" - "net/http" - "net/netip" - "net/textproto" - "testing" - - "github.com/google/go-cmp/cmp" -) - -type stubSourceExtractor struct { - name string - source Source - result extractionResult - err error - calls *int -} - -func (s *stubSourceExtractor) Extract(context.Context, *http.Request) (extractionResult, error) { - if s.calls != nil { - *s.calls = *s.calls + 1 - } - - return s.result, s.err -} - -func (s *stubSourceExtractor) Name() string { - return s.name -} - -func (s *stubSourceExtractor) Source() Source { - if s.source.valid() { - return s.source - } - - return HeaderSource(s.name) -} - -func TestChainedSource_Extract(t *testing.T) { - extractor := mustNewExtractor(t) - - tests := []struct { - name string - sources []sourceExtractor - remoteAddr string - xff string - xRealIP string - wantValid bool - wantIP string - wantSource Source - }{ - { - name: "first source succeeds", - sources: []sourceExtractor{ - &forwardedForSource{extractor: extractor}, - &remoteAddrSource{extractor: extractor}, - }, - remoteAddr: "127.0.0.1:8080", - xff: "1.1.1.1", - wantValid: true, - wantIP: "1.1.1.1", - wantSource: SourceXForwardedFor, - }, - { - name: "fallback to second source", - sources: []sourceExtractor{ - &forwardedForSource{extractor: extractor}, - &remoteAddrSource{extractor: extractor}, - }, - remoteAddr: "1.1.1.1:8080", - xff: "", - wantValid: true, - wantIP: "1.1.1.1", - wantSource: SourceRemoteAddr, - }, - { - name: "all sources fail", - sources: []sourceExtractor{ - &forwardedForSource{extractor: extractor}, - &remoteAddrSource{extractor: extractor}, - }, - remoteAddr: "127.0.0.1:8080", - xff: "", - wantValid: false, - }, - { - name: "custom priority order", - sources: []sourceExtractor{ - &singleHeaderSource{ - extractor: extractor, - headerName: "X-Real-IP", - headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), - sourceName: SourceXRealIP, - }, - &forwardedForSource{extractor: extractor}, - &remoteAddrSource{extractor: extractor}, - }, - remoteAddr: "127.0.0.1:8080", - xff: "8.8.8.8", - xRealIP: "1.1.1.1", - wantValid: true, - wantIP: "1.1.1.1", - wantSource: SourceXRealIP, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - chained := newChainedSource(extractor, tt.sources...) - - req := &http.Request{ - RemoteAddr: tt.remoteAddr, - Header: make(http.Header), - } - if tt.xff != "" { - req.Header.Set("X-Forwarded-For", tt.xff) - } - if tt.xRealIP != "" { - req.Header.Set("X-Real-IP", tt.xRealIP) - } - - result, err := chained.Extract(context.Background(), req) - - if tt.wantValid { - if err != nil { - t.Errorf("Extract() error = %v, want nil", err) - } - want := netip.MustParseAddr(tt.wantIP) - if result.IP != want { - t.Errorf("Extract() IP = %v, want %v", result.IP, want) - } - - if result.Source != tt.wantSource { - t.Errorf("result.Source = %q, want %q", result.Source, tt.wantSource) - } - } else { - if err == nil { - t.Errorf("Extract() error = nil, want non-nil") - } - } - }) - } -} - -func TestChainedSource_ContextCanceledPrecheck(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } - - tests := []struct { - name string - ctx func() context.Context - want struct { - Calls int - ErrCanceled bool - IP string - Source Source - } - }{ - { - name: "already canceled context does not call sources", - ctx: func() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - return ctx - }, - want: struct { - Calls int - ErrCanceled bool - IP string - Source Source - }{Calls: 0, ErrCanceled: true}, - }, - { - name: "active context calls first source", - ctx: func() context.Context { - return context.Background() - }, - want: struct { - Calls int - ErrCanceled bool - IP string - Source Source - }{Calls: 1, ErrCanceled: false, IP: "1.1.1.1", Source: HeaderSource("stub_source")}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - calls := 0 - source := &stubSourceExtractor{ - name: "stub_source", - source: HeaderSource("stub_source"), - result: extractionResult{IP: netip.MustParseAddr("1.1.1.1"), Source: HeaderSource("stub_source")}, - calls: &calls, - } - chained := newChainedSource(extractor, source) - - result, extractErr := chained.Extract(tt.ctx(), &http.Request{Header: make(http.Header)}) - - got := struct { - Calls int - ErrCanceled bool - IP string - Source Source - }{ - Calls: calls, - ErrCanceled: errors.Is(extractErr, context.Canceled), - Source: result.Source, - } - if result.IP.IsValid() { - got.IP = result.IP.String() - } - - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatalf("chained result mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestChainedSource_FillsEmptySourceNameFromSource(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } - - calls := 0 - source := &stubSourceExtractor{ - name: "custom_source", - result: extractionResult{IP: netip.MustParseAddr("1.1.1.1")}, - calls: &calls, - } - chained := newChainedSource(extractor, source) - - got, extractErr := chained.Extract(context.Background(), &http.Request{Header: make(http.Header)}) - if extractErr != nil { - t.Fatalf("Extract() error = %v", extractErr) - } - - gotView := struct { - Calls int - IP string - Source Source - }{ - Calls: calls, - IP: got.IP.String(), - Source: got.Source, - } - wantView := struct { - Calls int - IP string - Source Source - }{ - Calls: 1, - IP: "1.1.1.1", - Source: HeaderSource("custom_source"), - } - - if diff := cmp.Diff(wantView, gotView); diff != "" { - t.Fatalf("extraction mismatch (-want +got):\n%s", diff) - } -} - -func TestChainedSource_Name(t *testing.T) { - extractor := mustNewExtractor(t) - sources := []sourceExtractor{ - &forwardedForSource{extractor: extractor}, - &singleHeaderSource{ - extractor: extractor, - headerName: "X-Real-IP", - headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), - sourceName: SourceXRealIP, - }, - &remoteAddrSource{extractor: extractor}, - } - - chained := newChainedSource(extractor, sources...) - name := chained.Name() - - expectedName := "chained[x_forwarded_for,x_real_ip,remote_addr]" - if name != expectedName { - t.Errorf("Name() = %q, want %q", name, expectedName) - } -} - -func TestSourceFactories(t *testing.T) { - extractor := mustNewExtractor(t) - - t.Run("Forwarded source", func(t *testing.T) { - source := newForwardedSource(extractor) - if source.Source() != SourceForwarded { - t.Errorf("newForwardedSource() source = %q, want %q", source.Source(), SourceForwarded) - } - }) - - t.Run("XForwardedFor source", func(t *testing.T) { - source := newForwardedForSource(extractor) - if source.Source() != SourceXForwardedFor { - t.Errorf("newForwardedForSource() source = %q, want %q", source.Source(), SourceXForwardedFor) - } - }) - - t.Run("XRealIP source", func(t *testing.T) { - source := newSingleHeaderSource(extractor, "X-Real-IP") - if source.Source() != SourceXRealIP { - t.Errorf("newSingleHeaderSource(X-Real-IP) source = %q, want %q", source.Source(), SourceXRealIP) - } - - single, ok := source.(*singleHeaderSource) - if !ok { - t.Fatalf("newSingleHeaderSource() type = %T, want *singleHeaderSource", source) - } - - wantHeaderKey := textproto.CanonicalMIMEHeaderKey("X-Real-IP") - if single.headerKey != wantHeaderKey { - t.Errorf("newSingleHeaderSource(X-Real-IP) headerKey = %q, want %q", single.headerKey, wantHeaderKey) - } - }) - - t.Run("RemoteAddr source", func(t *testing.T) { - source := newRemoteAddrSource(extractor) - if source.Source() != SourceRemoteAddr { - t.Errorf("newRemoteAddrSource() source = %q, want %q", source.Source(), SourceRemoteAddr) - } - }) - - t.Run("Custom header source", func(t *testing.T) { - source := newSingleHeaderSource(extractor, "X-Custom-Header") - if got, want := source.Source(), HeaderSource("X-Custom-Header"); got != want { - t.Errorf("newSingleHeaderSource() source = %q, want %q", got, want) - } - }) -} - -func TestHeaderSource_String(t *testing.T) { - tests := []struct { - input string - want string - }{ - { - input: "X-Forwarded-For", - want: "x_forwarded_for", - }, - { - input: "Forwarded", - want: "forwarded", - }, - { - input: "X-Real-IP", - want: "x_real_ip", - }, - { - input: "CF-Connecting-IP", - want: "cf_connecting_ip", - }, - { - input: "UPPERCASE-HEADER", - want: "uppercase_header", - }, - { - input: "already_underscored", - want: "already_underscored", - }, - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - got := HeaderSource(tt.input).String() - if got != tt.want { - t.Errorf("HeaderSource(%q).String() = %q, want %q", tt.input, got, tt.want) - } - }) - } -} - -func TestSourceUnavailableErrors(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) - } - - t.Run("forwarded_for_missing", func(t *testing.T) { - source := &forwardedForSource{extractor: extractor} - req := &http.Request{Header: make(http.Header)} - - _, extractErr := source.Extract(context.Background(), req) - if !errors.Is(extractErr, ErrSourceUnavailable) { - t.Fatalf("error = %v, want ErrSourceUnavailable", extractErr) - } - }) - - t.Run("forwarded_missing", func(t *testing.T) { - source := &forwardedSource{extractor: extractor} - req := &http.Request{Header: make(http.Header)} - - _, extractErr := source.Extract(context.Background(), req) - if !errors.Is(extractErr, ErrSourceUnavailable) { - t.Fatalf("error = %v, want ErrSourceUnavailable", extractErr) - } - }) - - t.Run("single_header_missing", func(t *testing.T) { - source := &singleHeaderSource{ - extractor: extractor, - headerName: "X-Real-IP", - headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), - sourceName: SourceXRealIP, - } - req := &http.Request{Header: make(http.Header)} - - _, extractErr := source.Extract(context.Background(), req) - if !errors.Is(extractErr, ErrSourceUnavailable) { - t.Fatalf("error = %v, want ErrSourceUnavailable", extractErr) - } - }) - - t.Run("remote_addr_missing", func(t *testing.T) { - source := &remoteAddrSource{extractor: extractor} - req := &http.Request{Header: make(http.Header)} - - _, extractErr := source.Extract(context.Background(), req) - if !errors.Is(extractErr, ErrSourceUnavailable) { - t.Fatalf("error = %v, want ErrSourceUnavailable", extractErr) - } - }) -} diff --git a/source_chained.go b/source_chained.go new file mode 100644 index 0000000..fb39b49 --- /dev/null +++ b/source_chained.go @@ -0,0 +1,56 @@ +package clientip + +import "strings" + +type chainedSource struct { + sources []sourceExtractor + sourceName string + isTerminal func(error) bool +} + +func newChainedSource(isTerminal func(error) bool, sources ...sourceExtractor) *chainedSource { + names := make([]string, len(sources)) + for i, s := range sources { + names[i] = s.name() + } + + return &chainedSource{ + sources: sources, + sourceName: "chained[" + strings.Join(names, ",") + "]", + isTerminal: isTerminal, + } +} + +func (c *chainedSource) extract(r requestView) (Extraction, error) { + var lastErr error + for i, source := range c.sources { + // Context is already checked by extractWithSource before the first + // source; only re-check between subsequent sources in the chain. + if i > 0 { + if err := r.context().Err(); err != nil { + return Extraction{}, err + } + } + + result, err := source.extract(r) + if err == nil { + return result, nil + } + + if c.isTerminal != nil && c.isTerminal(err) { + return Extraction{}, err + } + + lastErr = err + } + + return Extraction{}, lastErr +} + +func (c *chainedSource) name() string { + return c.sourceName +} + +func (c *chainedSource) sourceInfo() Source { + return Source{} +} diff --git a/source_chained_test.go b/source_chained_test.go new file mode 100644 index 0000000..f19dc55 --- /dev/null +++ b/source_chained_test.go @@ -0,0 +1,246 @@ +package clientip + +import ( + "context" + "errors" + "net/netip" + "testing" +) + +// mockSourceExtractor is a test double for sourceExtractor. +type mockSourceExtractor struct { + extractFn func(r requestView) (Extraction, error) + nameValue string + sourceValue Source +} + +func (m *mockSourceExtractor) extract(r requestView) (Extraction, error) { + return m.extractFn(r) +} + +func (m *mockSourceExtractor) name() string { + return m.nameValue +} + +func (m *mockSourceExtractor) sourceInfo() Source { + return m.sourceValue +} + +func TestChainedSource_ReturnsFirstSuccess(t *testing.T) { + wantIP := netip.MustParseAddr("1.2.3.4") + wantSource := SourceXForwardedFor + + first := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{IP: wantIP, Source: wantSource}, nil + }, + nameValue: "first", + } + second := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + t.Fatal("second source should not be called") + return Extraction{}, nil + }, + nameValue: "second", + } + + chain := newChainedSource(nil, first, second) + result, err := chain.extract(requestView{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } + if result.Source != wantSource { + t.Errorf("Source = %v, want %v", result.Source, wantSource) + } +} + +func TestChainedSource_SkipsNonTerminalErrors(t *testing.T) { + wantIP := netip.MustParseAddr("5.6.7.8") + nonTerminal := errors.New("not terminal") + + first := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{}, nonTerminal + }, + nameValue: "first", + } + second := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{IP: wantIP, Source: SourceRemoteAddr}, nil + }, + nameValue: "second", + } + + chain := newChainedSource(sourceIsTerminalError, first, second) + result, err := chain.extract(requestView{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} + +func TestChainedSource_SkipsErrSourceUnavailable(t *testing.T) { + wantIP := netip.MustParseAddr("10.0.0.1") + + unavailable := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{}, &ExtractionError{Err: ErrSourceUnavailable, Source: SourceXForwardedFor} + }, + nameValue: "unavailable", + } + fallback := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{IP: wantIP, Source: SourceRemoteAddr}, nil + }, + nameValue: "fallback", + } + + chain := newChainedSource(sourceIsTerminalError, unavailable, fallback) + result, err := chain.extract(requestView{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} + +func TestChainedSource_ContextCanceledIsTerminal(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + first := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{}, context.Canceled + }, + nameValue: "first", + } + second := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + t.Fatal("second source should not be called after terminal error") + return Extraction{}, nil + }, + nameValue: "second", + } + + chain := newChainedSource(sourceIsTerminalError, first, second) + _, err := chain.extract(requestView{ctx: ctx}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("error = %v, want context.Canceled", err) + } +} + +func TestChainedSource_TerminalErrorStopsChain(t *testing.T) { + terminalErr := &ExtractionError{Err: ErrUntrustedProxy, Source: SourceXForwardedFor} + + first := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{}, terminalErr + }, + nameValue: "first", + } + second := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + t.Fatal("second source should not be called after terminal error") + return Extraction{}, nil + }, + nameValue: "second", + } + + chain := newChainedSource(sourceIsTerminalError, first, second) + _, err := chain.extract(requestView{}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrUntrustedProxy) { + t.Errorf("error = %v, want ErrUntrustedProxy", err) + } +} + +func TestChainedSource_AllFailReturnsLastError(t *testing.T) { + err1 := &ExtractionError{Err: ErrSourceUnavailable, Source: SourceXForwardedFor} + err2 := &ExtractionError{Err: ErrSourceUnavailable, Source: SourceRemoteAddr} + + first := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{}, err1 + }, + nameValue: "first", + } + second := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + return Extraction{}, err2 + }, + nameValue: "second", + } + + chain := newChainedSource(sourceIsTerminalError, first, second) + _, err := chain.extract(requestView{}) + if !errors.Is(err, err2) { + t.Errorf("error = %v, want %v (last error)", err, err2) + } +} + +func TestChainedSource_Name(t *testing.T) { + a := &mockSourceExtractor{nameValue: "alpha"} + b := &mockSourceExtractor{nameValue: "beta"} + chain := newChainedSource(nil, a, b) + + want := "chained[alpha,beta]" + if got := chain.name(); got != want { + t.Errorf("name() = %q, want %q", got, want) + } +} + +func TestChainedSource_SourceInfo(t *testing.T) { + chain := newChainedSource(nil, &mockSourceExtractor{nameValue: "a"}) + got := chain.sourceInfo() + if got.valid() { + t.Errorf("sourceInfo() should return invalid Source, got %v", got) + } +} + +func TestChainedSource_ContextCanceledBeforeSecondSource(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + firstCalled := false + secondCalled := false + first := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + firstCalled = true + cancel() // cancel context after first source runs + return Extraction{}, &ExtractionError{Err: ErrSourceUnavailable, Source: SourceXForwardedFor} + }, + nameValue: "first", + } + second := &mockSourceExtractor{ + extractFn: func(r requestView) (Extraction, error) { + secondCalled = true + return Extraction{}, nil + }, + nameValue: "second", + } + + chain := newChainedSource(sourceIsTerminalError, first, second) + _, err := chain.extract(requestView{ctx: ctx}) + if err == nil { + t.Fatal("expected error for cancelled context") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("error = %v, want context.Canceled", err) + } + if !firstCalled { + t.Error("first source should have been called") + } + if secondCalled { + t.Error("second source should not be called after context cancellation") + } +} diff --git a/source_compile.go b/source_compile.go new file mode 100644 index 0000000..1c01863 --- /dev/null +++ b/source_compile.go @@ -0,0 +1,45 @@ +package clientip + +type sourceExtractor interface { + extract(r requestView) (Extraction, error) + name() string + sourceInfo() Source +} + +type sourceExtractorKind uint8 + +const ( + sourceExtractorKindForwarded sourceExtractorKind = iota + 1 + sourceExtractorKindXForwardedFor + sourceExtractorKindSingleHeader + sourceExtractorKindRemoteAddr +) + +type sourceSpec struct { + kind sourceExtractorKind + source Source + headerName string +} + +type sourceExecuteFunc func(requestView, sourceSpec) (Extraction, error) + +type compiledSource struct { + spec sourceSpec + execute sourceExecuteFunc +} + +func compileSource(spec sourceSpec, execute sourceExecuteFunc) sourceExtractor { + return &compiledSource{spec: spec, execute: execute} +} + +func (s *compiledSource) extract(r requestView) (Extraction, error) { + return s.execute(r, s.spec) +} + +func (s *compiledSource) name() string { + return s.spec.source.String() +} + +func (s *compiledSource) sourceInfo() Source { + return s.spec.source +} diff --git a/source_execution.go b/source_execution.go new file mode 100644 index 0000000..dad1636 --- /dev/null +++ b/source_execution.go @@ -0,0 +1,439 @@ +package clientip + +import ( + "context" + "errors" + "fmt" + "net/textproto" +) + +func (e *Extractor) buildSourceChain(cfg *config) sourceExtractor { + sources := make([]sourceExtractor, 0, len(cfg.sourcePriority)) + for _, configuredSource := range cfg.sourcePriority { + spec := compileSpecFromSource(configuredSource) + executor := e.compileExecutor(spec, configuredSource) + sources = append(sources, compileSource(spec, executor)) + } + + if len(sources) == 1 { + return sources[0] + } + + return newChainedSource(sourceIsTerminalError, sources...) +} + +func compileSpecFromSource(source Source) sourceSpec { + source = canonicalSource(source) + spec := sourceSpec{source: source} + + switch source.kind { + case sourceForwarded: + spec.kind = sourceExtractorKindForwarded + spec.headerName = "Forwarded" + case sourceXForwardedFor: + spec.kind = sourceExtractorKindXForwardedFor + spec.headerName = "X-Forwarded-For" + case sourceRemoteAddr: + spec.kind = sourceExtractorKindRemoteAddr + default: + spec.kind = sourceExtractorKindSingleHeader + headerName, _ := source.headerKey() + spec.headerName = textproto.CanonicalMIMEHeaderKey(headerName) + } + + return spec +} + +func (e *Extractor) compileExecutor(spec sourceSpec, configuredSource Source) sourceExecuteFunc { + source := canonicalSource(configuredSource) + // Pre-compute source name string once to avoid per-call allocations + // from normalizeSourceName (strings.ToLower + ReplaceAll). + sourceName := source.String() + // Pre-allocate the source-unavailable error once per source to avoid + // allocating on every fallback miss in multi-source chains. + sourceUnavailableErr := &ExtractionError{Err: ErrSourceUnavailable, Source: source} + + switch spec.kind { + case sourceExtractorKindForwarded: + ce := chainExtractor{policy: chainPolicy{ + headerName: "Forwarded", + parseValues: func(values []string) ([]string, error) { + parts, err := parseForwardedValues(values, e.config.maxChainLength) + if err != nil { + return nil, adaptForwardedParseError(err, source, e) + } + return parts, nil + }, + clientIP: e.clientIP, + trustedProxy: e.proxy, + selection: e.config.chainSelection, + collectDebugInfo: e.config.debugMode, + untrustedChainSep: ", ", + }} + return func(r requestView, _ sourceSpec) (Extraction, error) { + result, failure, err := ce.extract(r, source) + if err != nil { + e.handleChainError(r, source, err, + "Forwarded chain exceeds configured maximum length", + func(err error) { + if !errors.Is(err, ErrInvalidForwardedHeader) { + return + } + e.config.metrics.RecordSecurityEvent(SecurityEventMalformedForwarded) + e.logSecurityWarning(r, source, SecurityEventMalformedForwarded, "malformed Forwarded header received", "parse_error", err.Error()) + }, + ) + return Extraction{}, err + } + if failure != nil { + if failure.kind == failureSourceUnavailable { + return Extraction{}, sourceUnavailableErr + } + return Extraction{}, e.adaptChainFailure(r, source, failure, "request received from untrusted proxy while Forwarded is present") + } + e.config.metrics.RecordExtractionSuccess(sourceName) + return result, nil + } + + case sourceExtractorKindXForwardedFor: + ce := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: func(values []string) ([]string, error) { + parts, err := parseXFFValues(values, e.config.maxChainLength) + if err != nil { + return nil, adaptXFFParseError(err, source, e) + } + return parts, nil + }, + clientIP: e.clientIP, + trustedProxy: e.proxy, + selection: e.config.chainSelection, + collectDebugInfo: e.config.debugMode, + untrustedChainSep: ", ", + }} + return func(r requestView, _ sourceSpec) (Extraction, error) { + result, failure, err := ce.extract(r, source) + if err != nil { + e.handleChainError(r, source, err, + "X-Forwarded-For chain exceeds configured maximum length", + nil, + ) + return Extraction{}, err + } + if failure != nil { + if failure.kind == failureSourceUnavailable { + return Extraction{}, sourceUnavailableErr + } + return Extraction{}, e.adaptChainFailure(r, source, failure, "request received from untrusted proxy while X-Forwarded-For is present") + } + e.config.metrics.RecordExtractionSuccess(sourceName) + return result, nil + } + + case sourceExtractorKindRemoteAddr: + re := remoteAddrExtractor{clientIPPolicy: e.clientIP} + return func(r requestView, _ sourceSpec) (Extraction, error) { + result, failure := re.extract(r.remoteAddr(), source) + if failure != nil { + if failure.kind == failureSourceUnavailable { + return Extraction{}, sourceUnavailableErr + } + e.recordInvalidClientIPDisposition(failure.clientIPDisposition) + e.config.metrics.RecordExtractionFailure(sourceName) + return Extraction{}, adaptRemoteAddrFailure(failure, source) + } + e.config.metrics.RecordExtractionSuccess(sourceName) + return result, nil + } + + default: + headerName := spec.headerName + she := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: headerName, + clientIP: e.clientIP, + trustedProxy: e.proxy, + }} + return func(r requestView, _ sourceSpec) (Extraction, error) { + result, failure := she.extract(r, source) + if failure != nil { + if failure.kind == failureSourceUnavailable { + return Extraction{}, sourceUnavailableErr + } + return Extraction{}, e.adaptSingleHeaderFailure(r, source, failure) + } + e.config.metrics.RecordExtractionSuccess(sourceName) + return result, nil + } + } +} + +func sourceIsTerminalError(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + + if errors.Is(err, ErrSourceUnavailable) { + return false + } + + return errors.Is(err, ErrInvalidIP) || + errors.Is(err, ErrMultipleSingleIPHeaders) || + errors.Is(err, ErrUntrustedProxy) || + errors.Is(err, ErrNoTrustedProxies) || + errors.Is(err, ErrTooFewTrustedProxies) || + errors.Is(err, ErrTooManyTrustedProxies) || + errors.Is(err, ErrChainTooLong) || + errors.Is(err, ErrInvalidForwardedHeader) +} + +func (e *Extractor) logSecurityWarning(r requestView, source Source, event, msg string, attrs ...any) { + baseAttrs := []any{ + "event", event, + "source", source.String(), + "path", r.path(), + "remote_addr", r.remoteAddr(), + } + + baseAttrs = append(baseAttrs, attrs...) + e.config.logger.WarnContext(r.context(), msg, baseAttrs...) +} + +func proxyValidationWarningDetails(err error) (event, msg string, ok bool) { + switch { + case errors.Is(err, ErrNoTrustedProxies): + return SecurityEventNoTrustedProxies, "no trusted proxies found in request chain", true + case errors.Is(err, ErrTooFewTrustedProxies): + return SecurityEventTooFewTrustedProxies, "trusted proxy count below configured minimum", true + case errors.Is(err, ErrTooManyTrustedProxies): + return SecurityEventTooManyTrustedProxies, "trusted proxy count exceeds configured maximum", true + default: + return "", "", false + } +} + +func (e *Extractor) logProxyValidationWarning(r requestView, source Source, err error) { + event, msg, ok := proxyValidationWarningDetails(err) + if !ok { + return + } + + var proxyErr *ProxyValidationError + if errors.As(err, &proxyErr) { + e.logSecurityWarning(r, source, event, msg, + "trusted_proxy_count", proxyErr.TrustedProxyCount, + "min_trusted_proxies", proxyErr.MinTrustedProxies, + "max_trusted_proxies", proxyErr.MaxTrustedProxies, + ) + return + } + + e.logSecurityWarning(r, source, event, msg) +} + +func (e *Extractor) handleChainError( + r requestView, + source Source, + err error, + chainTooLongMessage string, + handleParseError func(error), +) { + if errors.Is(err, ErrChainTooLong) { + var chainErr *ChainTooLongError + if errors.As(err, &chainErr) { + e.logSecurityWarning(r, source, SecurityEventChainTooLong, chainTooLongMessage, + "chain_length", chainErr.ChainLength, + "max_length", chainErr.MaxLength, + ) + } else { + e.logSecurityWarning(r, source, SecurityEventChainTooLong, chainTooLongMessage) + } + } + + if handleParseError != nil { + handleParseError(err) + } + + e.config.metrics.RecordExtractionFailure(source.String()) +} + +func (e *Extractor) adaptChainFailure(r requestView, source Source, failure *extractionFailure, untrustedProxyMessage string) error { + if failure == nil { + return &ExtractionError{Err: ErrInvalidIP, Source: source} + } + + switch failure.kind { + case failureSourceUnavailable: + return &ExtractionError{Err: ErrSourceUnavailable, Source: source} + case failureUntrustedProxy: + e.config.metrics.RecordSecurityEvent(SecurityEventUntrustedProxy) + e.logSecurityWarning(r, source, SecurityEventUntrustedProxy, untrustedProxyMessage) + e.config.metrics.RecordExtractionFailure(source.String()) + return &ProxyValidationError{ + ExtractionError: ExtractionError{Err: ErrUntrustedProxy, Source: source}, + Chain: failure.chain, + TrustedProxyCount: failure.trustedProxyCount, + MinTrustedProxies: failure.minTrustedProxies, + MaxTrustedProxies: failure.maxTrustedProxies, + } + case failureProxyValidation: + err := &ProxyValidationError{ + ExtractionError: ExtractionError{ + Err: e.validateProxyCount(failure.trustedProxyCount), + Source: source, + }, + Chain: failure.chain, + TrustedProxyCount: failure.trustedProxyCount, + MinTrustedProxies: failure.minTrustedProxies, + MaxTrustedProxies: failure.maxTrustedProxies, + } + e.logProxyValidationWarning(r, source, err) + e.config.metrics.RecordExtractionFailure(source.String()) + return err + case failureEmptyChain: + e.config.metrics.RecordExtractionFailure(source.String()) + return &ExtractionError{Err: ErrInvalidIP, Source: source} + case failureInvalidClientIP: + e.recordInvalidClientIPDisposition(failure.clientIPDisposition) + e.config.metrics.RecordExtractionFailure(source.String()) + return &InvalidIPError{ + ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: source}, + Chain: failure.chain, + ExtractedIP: failure.extractedIP, + Index: failure.index, + TrustedProxies: failure.trustedProxyCount, + } + default: + e.config.metrics.RecordExtractionFailure(source.String()) + return &ExtractionError{Err: ErrInvalidIP, Source: source} + } +} + +func (e *Extractor) adaptSingleHeaderFailure(r requestView, sourceName Source, failure *extractionFailure) error { + if failure == nil { + return &ExtractionError{Err: ErrInvalidIP, Source: sourceName} + } + + switch failure.kind { + case failureSourceUnavailable: + return &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName} + case failureMultipleHeaders: + e.config.metrics.RecordSecurityEvent(SecurityEventMultipleHeaders) + e.logSecurityWarning(r, sourceName, SecurityEventMultipleHeaders, "multiple single-IP headers received - possible spoofing attempt", + "header", failure.headerName, + "header_count", failure.headerCount, + ) + e.config.metrics.RecordExtractionFailure(sourceName.String()) + return &MultipleHeadersError{ + ExtractionError: ExtractionError{Err: ErrMultipleSingleIPHeaders, Source: sourceName}, + HeaderCount: failure.headerCount, + HeaderName: failure.headerName, + RemoteAddr: failure.remoteAddr, + } + case failureUntrustedProxy: + e.config.metrics.RecordSecurityEvent(SecurityEventUntrustedProxy) + e.logSecurityWarning(r, sourceName, SecurityEventUntrustedProxy, "request received from untrusted proxy while single-header source is present", + "header", failure.headerName, + ) + e.config.metrics.RecordExtractionFailure(sourceName.String()) + return &ProxyValidationError{ + ExtractionError: ExtractionError{Err: ErrUntrustedProxy, Source: sourceName}, + Chain: failure.chain, + TrustedProxyCount: failure.trustedProxyCount, + MinTrustedProxies: failure.minTrustedProxies, + MaxTrustedProxies: failure.maxTrustedProxies, + } + case failureInvalidClientIP: + e.recordInvalidClientIPDisposition(failure.clientIPDisposition) + e.config.metrics.RecordExtractionFailure(sourceName.String()) + return &InvalidIPError{ + ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: sourceName}, + ExtractedIP: failure.extractedIP, + } + default: + return &ExtractionError{Err: ErrInvalidIP, Source: sourceName} + } +} + +func adaptRemoteAddrFailure(failure *extractionFailure, sourceName Source) error { + if failure == nil { + return &ExtractionError{Err: ErrInvalidIP, Source: sourceName} + } + + switch failure.kind { + case failureSourceUnavailable: + return &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName} + case failureInvalidClientIP: + return &RemoteAddrError{ + ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: sourceName}, + RemoteAddr: failure.remoteAddr, + } + default: + return &ExtractionError{Err: ErrInvalidIP, Source: sourceName} + } +} + +func adaptForwardedParseError(err error, source Source, extractor *Extractor) error { + if chainErr := adaptChainLengthError(err, source, extractor); chainErr != nil { + return chainErr + } + + return &ExtractionError{ + Err: fmt.Errorf("%w: %w", ErrInvalidForwardedHeader, err), + Source: source, + } +} + +func adaptXFFParseError(err error, source Source, extractor *Extractor) error { + if chainErr := adaptChainLengthError(err, source, extractor); chainErr != nil { + return chainErr + } + + return err +} + +func adaptChainLengthError(err error, source Source, extractor *Extractor) error { + var chainErr *chainTooLongParseError + if !errors.As(err, &chainErr) { + return nil + } + + extractor.config.metrics.RecordSecurityEvent(SecurityEventChainTooLong) + + return &ChainTooLongError{ + ExtractionError: ExtractionError{Err: ErrChainTooLong, Source: source}, + ChainLength: chainErr.ChainLength, + MaxLength: chainErr.MaxLength, + } +} + +func (e *Extractor) validateProxyCount(trustedCount int) error { + err := validateProxyCountPolicy(trustedCount, e.proxy) + if err == nil { + return nil + } + + switch { + case errors.Is(err, ErrNoTrustedProxies): + e.config.metrics.RecordSecurityEvent(SecurityEventNoTrustedProxies) + return ErrNoTrustedProxies + case errors.Is(err, ErrTooFewTrustedProxies): + e.config.metrics.RecordSecurityEvent(SecurityEventTooFewTrustedProxies) + return ErrTooFewTrustedProxies + case errors.Is(err, ErrTooManyTrustedProxies): + e.config.metrics.RecordSecurityEvent(SecurityEventTooManyTrustedProxies) + return ErrTooManyTrustedProxies + default: + return err + } +} + +func (e *Extractor) recordInvalidClientIPDisposition(disposition clientIPDisposition) { + switch disposition { + case clientIPInvalid: + e.config.metrics.RecordSecurityEvent(SecurityEventInvalidIP) + case clientIPReserved: + e.config.metrics.RecordSecurityEvent(SecurityEventReservedIP) + case clientIPPrivate: + e.config.metrics.RecordSecurityEvent(SecurityEventPrivateIP) + } +} diff --git a/source_failure.go b/source_failure.go new file mode 100644 index 0000000..19af3d7 --- /dev/null +++ b/source_failure.go @@ -0,0 +1,32 @@ +package clientip + +type extractionFailureKind uint8 + +const ( + failureUnknown extractionFailureKind = iota + failureSourceUnavailable + failureMultipleHeaders + failureUntrustedProxy + failureProxyValidation + failureEmptyChain + failureInvalidClientIP +) + +// errSourceUnavailable is a pre-allocated sentinel returned by extractors +// when the source header is absent. Only the kind field is read by callers. +var errSourceUnavailable = &extractionFailure{kind: failureSourceUnavailable} + +type extractionFailure struct { + kind extractionFailureKind + source Source + headerName string + headerCount int + remoteAddr string + chain string + index int + extractedIP string + trustedProxyCount int + minTrustedProxies int + maxTrustedProxies int + clientIPDisposition clientIPDisposition +} diff --git a/source_forwarded.go b/source_forwarded.go deleted file mode 100644 index e33898b..0000000 --- a/source_forwarded.go +++ /dev/null @@ -1,103 +0,0 @@ -package clientip - -import ( - "context" - "errors" - "net/http" - "strings" -) - -type forwardedForSource struct { - extractor *Extractor - sourceName Source - unavailableErr error -} - -type forwardedSource struct { - extractor *Extractor - sourceName Source - unavailableErr error -} - -func (s *forwardedForSource) Name() string { - if !s.sourceName.valid() { - return builtinSource(sourceXForwardedFor).String() - } - - return s.sourceName.String() -} - -func (s *forwardedSource) Name() string { - if !s.sourceName.valid() { - return builtinSource(sourceForwarded).String() - } - - return s.sourceName.String() -} - -func (s *forwardedForSource) Source() Source { - if !s.sourceName.valid() { - return builtinSource(sourceXForwardedFor) - } - - return s.sourceName -} - -func (s *forwardedSource) Source() Source { - if !s.sourceName.valid() { - return builtinSource(sourceForwarded) - } - - return s.sourceName -} - -func (s *forwardedSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - sourceName := s.Source() - forwardedValues := r.Header["Forwarded"] - if len(forwardedValues) == 0 { - return extractionResult{}, sourceUnavailableError(s.unavailableErr, sourceName) - } - - return s.extractor.extractChainSource( - ctx, - r, - sourceName, - forwardedValues, - func() string { - return strings.Join(forwardedValues, ", ") - }, - "request received from untrusted proxy while Forwarded is present", - "Forwarded chain exceeds configured maximum length", - s.extractor.parseForwardedValues, - func(err error) { - if !errors.Is(err, ErrInvalidForwardedHeader) { - return - } - - s.extractor.config.metrics.RecordSecurityEvent(securityEventMalformedForwarded) - s.extractor.logSecurityWarning(ctx, r, sourceName, securityEventMalformedForwarded, "malformed Forwarded header received", "parse_error", err.Error()) - }, - ) -} - -func (s *forwardedForSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - sourceName := s.Source() - xffValues := r.Header["X-Forwarded-For"] - if len(xffValues) == 0 { - return extractionResult{}, sourceUnavailableError(s.unavailableErr, sourceName) - } - - return s.extractor.extractChainSource( - ctx, - r, - sourceName, - xffValues, - func() string { - return strings.Join(xffValues, ", ") - }, - "request received from untrusted proxy while X-Forwarded-For is present", - "X-Forwarded-For chain exceeds configured maximum length", - s.extractor.parseXFFValues, - nil, - ) -} diff --git a/source_forwarded_test.go b/source_forwarded_test.go deleted file mode 100644 index 56cc90a..0000000 --- a/source_forwarded_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package clientip - -import ( - "context" - "net/http" - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestForwardedForSource_Extract(t *testing.T) { - extractor := mustNewExtractor(t) - source := &forwardedForSource{extractor: extractor} - - tests := []struct { - name string - xffHeaders []string - wantValid bool - wantIP string - wantErr error - wantErrType any - }{ - { - name: "single valid IP", - xffHeaders: []string{"1.1.1.1"}, - wantValid: true, - wantIP: "1.1.1.1", - }, - { - name: "multiple IPs in chain", - xffHeaders: []string{"1.1.1.1, 8.8.8.8"}, - wantValid: true, - wantIP: "1.1.1.1", - }, - { - name: "no XFF header", - xffHeaders: []string{}, - wantValid: false, - wantErrType: &ExtractionError{}, - }, - { - name: "multiple XFF headers are combined", - xffHeaders: []string{"1.1.1.1", "8.8.8.8"}, - wantValid: true, - wantIP: "1.1.1.1", - }, - { - name: "invalid IP in chain", - xffHeaders: []string{"not-an-ip"}, - wantValid: false, - }, - { - name: "private IP rejected", - xffHeaders: []string{"192.168.1.1"}, - wantValid: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := &http.Request{ - Header: make(http.Header), - } - for _, h := range tt.xffHeaders { - req.Header.Add("X-Forwarded-For", h) - } - - result, err := source.Extract(context.Background(), req) - - if tt.wantValid { - if err != nil { - t.Errorf("Extract() error = %v, want nil", err) - } - want := netip.MustParseAddr(tt.wantIP) - if result.IP != want { - t.Errorf("Extract() IP = %v, want %v", result.IP, want) - } - } else { - if err == nil { - t.Errorf("Extract() error = nil, want non-nil") - } - } - - if tt.wantErrType != nil { - if !errorIsType(err, tt.wantErrType) { - t.Errorf("Extract() error type = %T, want %T", err, tt.wantErrType) - } - } - - if tt.wantErr != nil { - if !errorContains(err, tt.wantErr) { - t.Errorf("Extract() error does not contain expected error: %v", tt.wantErr) - } - } - }) - } -} - -func TestForwardedForSource_Name(t *testing.T) { - extractor := mustNewExtractor(t) - source := &forwardedForSource{extractor: extractor} - - if source.Source() != SourceXForwardedFor { - t.Errorf("Source() = %q, want %q", source.Source(), SourceXForwardedFor) - } -} - -func TestForwardedSource_Extract(t *testing.T) { - extractor := mustNewExtractor(t) - source := &forwardedSource{extractor: extractor} - - tests := []struct { - name string - forwarded []string - wantValid bool - wantIP string - wantErr error - wantErrType any - }{ - { - name: "single valid for value", - forwarded: []string{"for=1.1.1.1"}, - wantValid: true, - wantIP: "1.1.1.1", - }, - { - name: "quoted IPv6 with port", - forwarded: []string{"for=\"[2606:4700:4700::1]:8080\""}, - wantValid: true, - wantIP: "2606:4700:4700::1", - }, - { - name: "no Forwarded header", - forwarded: nil, - wantValid: false, - wantErrType: &ExtractionError{}, - }, - { - name: "malformed Forwarded header", - forwarded: []string{"for=\"1.1.1.1"}, - wantValid: false, - wantErr: ErrInvalidForwardedHeader, - wantErrType: &ExtractionError{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := &http.Request{ - Header: make(http.Header), - } - for _, h := range tt.forwarded { - req.Header.Add("Forwarded", h) - } - - result, err := source.Extract(context.Background(), req) - - got := struct { - Valid bool - IP string - }{ - Valid: err == nil, - } - if err == nil { - got.IP = result.IP.String() - } - - want := struct { - Valid bool - IP string - }{ - Valid: tt.wantValid, - } - if tt.wantValid { - want.IP = netip.MustParseAddr(tt.wantIP).String() - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("Extract() mismatch (-want +got):\n%s", diff) - } - - if tt.wantErrType != nil { - if !errorIsType(err, tt.wantErrType) { - t.Errorf("Extract() error type = %T, want %T", err, tt.wantErrType) - } - } - - if tt.wantErr != nil { - if !errorContains(err, tt.wantErr) { - t.Errorf("Extract() error does not contain expected error: %v", tt.wantErr) - } - } - }) - } -} - -func TestForwardedSource_Name(t *testing.T) { - extractor := mustNewExtractor(t) - source := &forwardedSource{extractor: extractor} - - if source.Source() != SourceForwarded { - t.Errorf("Source() = %q, want %q", source.Source(), SourceForwarded) - } -} diff --git a/source_helpers.go b/source_helpers.go deleted file mode 100644 index 16fbd42..0000000 --- a/source_helpers.go +++ /dev/null @@ -1,19 +0,0 @@ -package clientip - -import "errors" - -func sourceUnavailableError(unavailableErr error, sourceName Source) error { - if unavailableErr != nil { - return unavailableErr - } - - return &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName} -} - -func wrapSourceUnavailableError(err, unavailableErr error, sourceName Source) error { - if !errors.Is(err, ErrSourceUnavailable) { - return err - } - - return sourceUnavailableError(unavailableErr, sourceName) -} diff --git a/source_remote_addr_extract.go b/source_remote_addr_extract.go new file mode 100644 index 0000000..96f5def --- /dev/null +++ b/source_remote_addr_extract.go @@ -0,0 +1,27 @@ +package clientip + +type remoteAddrExtractor struct { + clientIPPolicy clientIPPolicy +} + +func (e remoteAddrExtractor) extract(remoteAddr string, source Source) (Extraction, *extractionFailure) { + if remoteAddr == "" { + return Extraction{}, errSourceUnavailable + } + + ip := parseRemoteAddr(remoteAddr) + disposition := evaluateClientIP(ip, e.clientIPPolicy) + if disposition != clientIPValid { + return Extraction{}, &extractionFailure{ + kind: failureInvalidClientIP, + source: source, + remoteAddr: remoteAddr, + clientIPDisposition: disposition, + } + } + + return Extraction{ + IP: normalizeIP(ip), + Source: source, + }, nil +} diff --git a/source_remote_addr_extract_test.go b/source_remote_addr_extract_test.go new file mode 100644 index 0000000..2fc7d80 --- /dev/null +++ b/source_remote_addr_extract_test.go @@ -0,0 +1,139 @@ +package clientip + +import ( + "net/netip" + "testing" +) + +func TestRemoteAddrExtractor_ValidAddr(t *testing.T) { + ext := remoteAddrExtractor{} + source := SourceRemoteAddr + + result, failure := ext.extract("8.8.8.8:8080", source) + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("8.8.8.8") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } + if result.Source != source { + t.Errorf("Source = %v, want %v", result.Source, source) + } +} + +func TestRemoteAddrExtractor_ValidAddrWithoutPort(t *testing.T) { + ext := remoteAddrExtractor{} + source := SourceRemoteAddr + + result, failure := ext.extract("8.8.8.8", source) + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("8.8.8.8") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} + +func TestRemoteAddrExtractor_IPv6(t *testing.T) { + ext := remoteAddrExtractor{} + source := SourceRemoteAddr + + result, failure := ext.extract("[2606:4700::1]:443", source) + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("2606:4700::1") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} + +func TestRemoteAddrExtractor_EmptyAddr(t *testing.T) { + ext := remoteAddrExtractor{} + + _, failure := ext.extract("", SourceRemoteAddr) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure != errSourceUnavailable { + t.Errorf("failure = %+v, want errSourceUnavailable", failure) + } +} + +func TestRemoteAddrExtractor_InvalidIP(t *testing.T) { + ext := remoteAddrExtractor{} + + _, failure := ext.extract("not-an-ip", SourceRemoteAddr) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } +} + +func TestRemoteAddrExtractor_LoopbackIP(t *testing.T) { + ext := remoteAddrExtractor{} + + _, failure := ext.extract("127.0.0.1:8080", SourceRemoteAddr) + if failure == nil { + t.Fatal("expected failure for loopback, got nil") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } +} + +func TestRemoteAddrExtractor_UnspecifiedIP(t *testing.T) { + ext := remoteAddrExtractor{} + + _, failure := ext.extract("0.0.0.0:80", SourceRemoteAddr) + if failure == nil { + t.Fatal("expected failure for unspecified IP, got nil") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } +} + +func TestRemoteAddrExtractor_PrivateIPRejectedByDefault(t *testing.T) { + ext := remoteAddrExtractor{} + + _, failure := ext.extract("192.168.1.1:80", SourceRemoteAddr) + if failure == nil { + t.Fatal("expected failure for private IP with default policy, got nil") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } +} + +func TestRemoteAddrExtractor_PrivateIPAllowed(t *testing.T) { + ext := remoteAddrExtractor{ + clientIPPolicy: clientIPPolicy{AllowPrivateIPs: true}, + } + + result, failure := ext.extract("192.168.1.1:80", SourceRemoteAddr) + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("192.168.1.1") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} + +func TestRemoteAddrExtractor_SourcePreserved(t *testing.T) { + ext := remoteAddrExtractor{} + source := SourceRemoteAddr + + _, failure := ext.extract("not-valid", source) + if failure == nil { + t.Fatal("expected failure") + } + if failure.source != source { + t.Errorf("failure.source = %v, want %v", failure.source, source) + } +} diff --git a/source_remote_addr_test.go b/source_remote_addr_test.go deleted file mode 100644 index 963ecb3..0000000 --- a/source_remote_addr_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package clientip - -import ( - "context" - "net/http" - "net/netip" - "testing" -) - -func TestRemoteAddrSource_Extract(t *testing.T) { - extractor := mustNewExtractor(t) - source := &remoteAddrSource{extractor: extractor} - - tests := []struct { - name string - remoteAddr string - wantValid bool - wantIP string - }{ - { - name: "valid IPv4 with port", - remoteAddr: "1.1.1.1:12345", - wantValid: true, - wantIP: "1.1.1.1", - }, - { - name: "valid IPv6 with port", - remoteAddr: "[2606:4700:4700::1]:8080", - wantValid: true, - wantIP: "2606:4700:4700::1", - }, - { - name: "empty RemoteAddr", - remoteAddr: "", - wantValid: false, - }, - { - name: "loopback rejected", - remoteAddr: "127.0.0.1:8080", - wantValid: false, - }, - { - name: "private IP rejected", - remoteAddr: "192.168.1.1:8080", - wantValid: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := &http.Request{ - RemoteAddr: tt.remoteAddr, - } - - result, err := source.Extract(context.Background(), req) - - if tt.wantValid { - if err != nil { - t.Errorf("Extract() error = %v, want nil", err) - } - want := netip.MustParseAddr(tt.wantIP) - if result.IP != want { - t.Errorf("Extract() IP = %v, want %v", result.IP, want) - } - } else { - if err == nil { - t.Errorf("Extract() error = nil, want non-nil") - } - } - }) - } -} - -func TestRemoteAddrSource_Name(t *testing.T) { - extractor := mustNewExtractor(t) - source := &remoteAddrSource{extractor: extractor} - - if source.Source() != SourceRemoteAddr { - t.Errorf("Source() = %q, want %q", source.Source(), SourceRemoteAddr) - } -} diff --git a/source_request.go b/source_request.go new file mode 100644 index 0000000..9822f5d --- /dev/null +++ b/source_request.go @@ -0,0 +1,104 @@ +package clientip + +import ( + "context" + "net/http" + "net/textproto" +) + +type headerValuesFunc func(name string) []string + +type requestView struct { + ctx context.Context + remoteAddrValue string + pathValue string + headerMap map[string][]string + headerFunc headerValuesFunc +} + +func (r requestView) context() context.Context { + if r.ctx == nil { + return context.Background() + } + + return r.ctx +} + +func (r requestView) remoteAddr() string { + return r.remoteAddrValue +} + +func (r requestView) path() string { + return r.pathValue +} + +func (r requestView) values(name string) []string { + if r.headerMap != nil { + return r.headerMap[textproto.CanonicalMIMEHeaderKey(name)] + } + if r.headerFunc != nil { + return r.headerFunc(name) + } + + return nil +} + +// valuesCanonical performs a header lookup without canonicalizing the name. +// Callers must pass an already-canonical MIME header key (e.g. "X-Forwarded-For"). +func (r requestView) valuesCanonical(name string) []string { + if r.headerMap != nil { + return r.headerMap[name] + } + if r.headerFunc != nil { + return r.headerFunc(name) + } + + return nil +} + +func requestViewFromRequest(r *http.Request) requestView { + if r == nil { + return requestView{} + } + + view := requestView{ + ctx: r.Context(), + remoteAddrValue: r.RemoteAddr, + headerMap: map[string][]string(r.Header), + } + if r.URL != nil { + view.pathValue = r.URL.Path + } + + return view +} + +func requestViewFromInput(input Input) requestView { + view := requestView{ + ctx: requestInputContext(input), + remoteAddrValue: input.RemoteAddr, + pathValue: input.Path, + } + if input.Headers == nil { + return view + } + + if h, ok := input.Headers.(HeaderValuesFunc); ok { + if h == nil { + return view + } + view.headerFunc = headerValuesFunc(h) + return view + } + + // Deliberately catch typed nils (e.g. (*myHeaders)(nil)) so they behave + // the same as an unset Headers field rather than panicking at call time. + if isNilValue(input.Headers) { + return view + } + + view.headerFunc = func(name string) []string { + return input.Headers.Values(name) + } + return view +} diff --git a/source_single_header.go b/source_single_header.go index eae00ef..1a6db78 100644 --- a/source_single_header.go +++ b/source_single_header.go @@ -1,145 +1,64 @@ package clientip -import ( - "context" - "net/http" -) - -type singleHeaderSource struct { - extractor *Extractor - headerName string - headerKey string - sourceName Source - unavailableErr error -} - -func (s *singleHeaderSource) Name() string { - return s.sourceName.String() +type singleHeaderPolicy struct { + headerName string + clientIP clientIPPolicy + trustedProxy proxyPolicy } -func (s *singleHeaderSource) Source() Source { - return s.sourceName +type singleHeaderExtractor struct { + policy singleHeaderPolicy } -func (s *singleHeaderSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - sourceName := s.Source() - headerValues := r.Header[s.headerKey] +func (e singleHeaderExtractor) extract(req requestView, source Source) (Extraction, *extractionFailure) { + headerValues := req.valuesCanonical(e.policy.headerName) if len(headerValues) == 0 { - return extractionResult{}, sourceUnavailableError(s.unavailableErr, sourceName) + return Extraction{}, errSourceUnavailable } if len(headerValues) > 1 { - s.extractor.config.metrics.RecordSecurityEvent(securityEventMultipleHeaders) - s.extractor.logSecurityWarning(ctx, r, sourceName, securityEventMultipleHeaders, "multiple single-IP headers received - possible spoofing attempt", - "header", s.headerName, - "header_count", len(headerValues), - ) - s.extractor.config.metrics.RecordExtractionFailure(sourceName.String()) - return extractionResult{}, &MultipleHeadersError{ - ExtractionError: ExtractionError{ - Err: ErrMultipleSingleIPHeaders, - Source: sourceName, - }, - HeaderCount: len(headerValues), - HeaderName: s.headerName, - RemoteAddr: r.RemoteAddr, + return Extraction{}, &extractionFailure{ + kind: failureMultipleHeaders, + source: source, + headerName: e.policy.headerName, + headerCount: len(headerValues), + remoteAddr: req.remoteAddr(), } } headerValue := headerValues[0] if headerValue == "" { - return extractionResult{}, sourceUnavailableError(s.unavailableErr, sourceName) + return Extraction{}, errSourceUnavailable } - if len(s.extractor.config.trustedProxyCIDRs) > 0 { - remoteIP := parseRemoteAddr(r.RemoteAddr) - if !s.extractor.isTrustedProxy(remoteIP) { - s.extractor.config.metrics.RecordSecurityEvent(securityEventUntrustedProxy) - s.extractor.logSecurityWarning(ctx, r, sourceName, securityEventUntrustedProxy, "request received from untrusted proxy while single-header source is present", - "header", s.headerName, - ) - s.extractor.config.metrics.RecordExtractionFailure(sourceName.String()) - return extractionResult{}, &ProxyValidationError{ - ExtractionError: ExtractionError{ - Err: ErrUntrustedProxy, - Source: sourceName, - }, - Chain: headerValue, - TrustedProxyCount: 0, - MinTrustedProxies: s.extractor.config.minTrustedProxies, - MaxTrustedProxies: s.extractor.config.maxTrustedProxies, + if len(e.policy.trustedProxy.TrustedProxyCIDRs) > 0 { + remoteIP := parseRemoteAddr(req.remoteAddr()) + if !isTrustedProxy(remoteIP, e.policy.trustedProxy.TrustedProxyMatch, e.policy.trustedProxy.TrustedProxyCIDRs) { + return Extraction{}, &extractionFailure{ + kind: failureUntrustedProxy, + source: source, + headerName: e.policy.headerName, + chain: headerValue, + trustedProxyCount: 0, + minTrustedProxies: e.policy.trustedProxy.MinTrustedProxies, + maxTrustedProxies: e.policy.trustedProxy.MaxTrustedProxies, } } } ip := parseIP(headerValue) - if !s.extractor.isPlausibleClientIP(ip) { - s.extractor.config.metrics.RecordExtractionFailure(sourceName.String()) - return extractionResult{}, &InvalidIPError{ - ExtractionError: ExtractionError{ - Err: ErrInvalidIP, - Source: sourceName, - }, - ExtractedIP: headerValue, - } - } - - s.extractor.config.metrics.RecordExtractionSuccess(sourceName.String()) - return extractionResult{IP: normalizeIP(ip), Source: sourceName}, nil -} - -type remoteAddrSource struct { - extractor *Extractor - sourceName Source - unavailableErr error -} - -func (s *remoteAddrSource) Name() string { - if !s.sourceName.valid() { - return builtinSource(sourceRemoteAddr).String() - } - - return s.sourceName.String() -} - -func (s *remoteAddrSource) Source() Source { - if !s.sourceName.valid() { - return builtinSource(sourceRemoteAddr) - } - - return s.sourceName -} - -func (s *remoteAddrSource) Extract(ctx context.Context, r *http.Request) (extractionResult, error) { - sourceName := s.Source() - remoteAddr := r.RemoteAddr - result, err := s.extractor.extractRemoteAddr(remoteAddr) - if err != nil { - return extractionResult{}, wrapSourceUnavailableError(err, s.unavailableErr, sourceName) - } - - return result, nil -} - -func (e *Extractor) extractRemoteAddr(remoteAddr string) (extractionResult, error) { - remoteSource := builtinSource(sourceRemoteAddr) - - if remoteAddr == "" { - return extractionResult{}, &ExtractionError{Err: ErrSourceUnavailable, Source: remoteSource} - } - - ip := parseRemoteAddr(remoteAddr) - if !e.isPlausibleClientIP(ip) { - e.config.metrics.RecordExtractionFailure(remoteSource.String()) - return extractionResult{}, &RemoteAddrError{ - ExtractionError: ExtractionError{ - Err: ErrInvalidIP, - Source: remoteSource, - }, - RemoteAddr: remoteAddr, + disposition := evaluateClientIP(ip, e.policy.clientIP) + if disposition != clientIPValid { + return Extraction{}, &extractionFailure{ + kind: failureInvalidClientIP, + source: source, + extractedIP: headerValue, + clientIPDisposition: disposition, } } - e.config.metrics.RecordExtractionSuccess(remoteSource.String()) - return extractionResult{IP: normalizeIP(ip), Source: remoteSource}, nil + return Extraction{ + IP: normalizeIP(ip), + Source: source, + }, nil } diff --git a/source_single_header_test.go b/source_single_header_test.go index 57964fe..b214012 100644 --- a/source_single_header_test.go +++ b/source_single_header_test.go @@ -1,176 +1,222 @@ package clientip import ( - "context" - "errors" - "net/http" "net/netip" - "net/textproto" "testing" ) -func TestSingleHeaderSource_Extract(t *testing.T) { - extractor := mustNewExtractor(t) - - tests := []struct { - name string - headerName string - headerValue string - wantValid bool - wantIP string - }{ - { - name: "valid IP", - headerName: "X-Real-IP", - headerValue: "1.1.1.1", - wantValid: true, - wantIP: "1.1.1.1", - }, - { - name: "IPv6", - headerName: "X-Real-IP", - headerValue: "2606:4700:4700::1", - wantValid: true, - wantIP: "2606:4700:4700::1", - }, - { - name: "empty header", - headerName: "X-Real-IP", - headerValue: "", - wantValid: false, - }, - { - name: "invalid IP", - headerName: "X-Real-IP", - headerValue: "not-an-ip", - wantValid: false, - }, - { - name: "private IP rejected", - headerName: "X-Real-IP", - headerValue: "192.168.1.1", - wantValid: false, - }, - { - name: "custom header name", - headerName: "CF-Connecting-IP", - headerValue: "1.1.1.1", - wantValid: true, - wantIP: "1.1.1.1", +func TestSingleHeaderExtractor_ValidIP(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + }} + source := SourceXRealIP + req := requestView{ + remoteAddrValue: "10.0.0.1:1234", + headerMap: map[string][]string{ + "X-Real-Ip": {"8.8.8.8"}, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - source := &singleHeaderSource{ - extractor: extractor, - headerName: tt.headerName, - headerKey: textproto.CanonicalMIMEHeaderKey(tt.headerName), - sourceName: HeaderSource(tt.headerName), - } - - req := &http.Request{ - Header: make(http.Header), - } - if tt.headerValue != "" { - req.Header.Set(tt.headerName, tt.headerValue) - } - - result, err := source.Extract(context.Background(), req) - - if tt.wantValid { - if err != nil { - t.Errorf("Extract() error = %v, want nil", err) - } - want := netip.MustParseAddr(tt.wantIP) - if result.IP != want { - t.Errorf("Extract() IP = %v, want %v", result.IP, want) - } - } else { - if err == nil { - t.Errorf("Extract() error = nil, want non-nil") - } - } - }) + result, failure := ext.extract(req, source) + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("8.8.8.8") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } + if result.Source != source { + t.Errorf("Source = %v, want %v", result.Source, source) } } -func TestSingleHeaderSource_Extract_MultipleHeaderValues(t *testing.T) { - extractor, err := New() - if err != nil { - t.Fatalf("New() error = %v", err) +func TestSingleHeaderExtractor_HeaderMissing(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + }} + req := requestView{ + headerMap: map[string][]string{}, } - source := &singleHeaderSource{ - extractor: extractor, - headerName: "X-Real-IP", - headerKey: textproto.CanonicalMIMEHeaderKey("X-Real-IP"), - sourceName: HeaderSource("X-Real-IP"), + _, failure := ext.extract(req, SourceXRealIP) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure != errSourceUnavailable { + t.Errorf("failure = %+v, want errSourceUnavailable", failure) } +} - req := &http.Request{ - RemoteAddr: "127.0.0.1:8080", - Header: make(http.Header), +func TestSingleHeaderExtractor_EmptyHeaderValue(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + }} + req := requestView{ + headerMap: map[string][]string{ + "X-Real-Ip": {""}, + }, } - req.Header.Add("X-Real-IP", "1.1.1.1") - req.Header.Add("X-Real-IP", "8.8.8.8") - _, extractErr := source.Extract(context.Background(), req) - if extractErr == nil { - t.Fatal("Extract() error = nil, want error") + _, failure := ext.extract(req, SourceXRealIP) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure != errSourceUnavailable { + t.Errorf("failure = %+v, want errSourceUnavailable", failure) } +} - if !errors.Is(extractErr, ErrMultipleSingleIPHeaders) { - t.Fatalf("error = %v, want ErrMultipleSingleIPHeaders", extractErr) +func TestSingleHeaderExtractor_MultipleHeaderValues(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + }} + req := requestView{ + remoteAddrValue: "10.0.0.1:1234", + headerMap: map[string][]string{ + "X-Real-Ip": {"8.8.4.4", "1.1.1.1"}, + }, } - var multipleHeadersErr *MultipleHeadersError - if !errors.As(extractErr, &multipleHeadersErr) { - t.Fatalf("error type = %T, want *MultipleHeadersError", extractErr) + _, failure := ext.extract(req, SourceXRealIP) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure.kind != failureMultipleHeaders { + t.Errorf("failure.kind = %v, want failureMultipleHeaders", failure.kind) + } + if failure.headerCount != 2 { + t.Errorf("failure.headerCount = %d, want 2", failure.headerCount) } + if failure.headerName != "X-Real-Ip" { + t.Errorf("failure.headerName = %q, want %q", failure.headerName, "X-Real-Ip") + } +} - if multipleHeadersErr.HeaderCount != 2 { - t.Fatalf("HeaderCount = %d, want 2", multipleHeadersErr.HeaderCount) +func TestSingleHeaderExtractor_UntrustedProxy(t *testing.T) { + trustedCIDR := netip.MustParsePrefix("10.0.0.0/8") + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + trustedProxy: proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, + TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), + }, + }} + // Remote addr is not in trusted CIDR. + req := requestView{ + remoteAddrValue: "5.5.5.5:4567", + headerMap: map[string][]string{ + "X-Real-Ip": {"9.9.9.9"}, + }, } - if multipleHeadersErr.HeaderName != "X-Real-IP" { - t.Fatalf("HeaderName = %q, want %q", multipleHeadersErr.HeaderName, "X-Real-IP") + _, failure := ext.extract(req, SourceXRealIP) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure.kind != failureUntrustedProxy { + t.Errorf("failure.kind = %v, want failureUntrustedProxy", failure.kind) } } -func TestSingleHeaderSource_Name(t *testing.T) { - extractor := mustNewExtractor(t) +func TestSingleHeaderExtractor_TrustedProxy(t *testing.T) { + trustedCIDR := netip.MustParsePrefix("10.0.0.0/8") + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + trustedProxy: proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, + TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), + }, + }} + // Remote addr IS in trusted CIDR. + req := requestView{ + remoteAddrValue: "10.0.0.1:4567", + headerMap: map[string][]string{ + "X-Real-Ip": {"9.9.9.9"}, + }, + } - tests := []struct { - headerName string - wantName string - }{ - { - headerName: "X-Real-IP", - wantName: "x_real_ip", + result, failure := ext.extract(req, SourceXRealIP) + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("9.9.9.9") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} + +func TestSingleHeaderExtractor_InvalidClientIP(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + }} + req := requestView{ + headerMap: map[string][]string{ + "X-Real-Ip": {"not-an-ip"}, }, - { - headerName: "CF-Connecting-IP", - wantName: "cf_connecting_ip", + } + + _, failure := ext.extract(req, SourceXRealIP) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } + if failure.extractedIP != "not-an-ip" { + t.Errorf("failure.extractedIP = %q, want %q", failure.extractedIP, "not-an-ip") + } +} + +func TestSingleHeaderExtractor_LoopbackIsInvalidClient(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + }} + req := requestView{ + headerMap: map[string][]string{ + "X-Real-Ip": {"127.0.0.1"}, }, - { - headerName: "X-Custom-Header", - wantName: "x_custom_header", + } + + _, failure := ext.extract(req, SourceXRealIP) + if failure == nil { + t.Fatal("expected failure for loopback IP, got nil") + } + if failure.kind != failureInvalidClientIP { + t.Errorf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } +} + +func TestSingleHeaderExtractor_IPv6(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Real-Ip", + }} + req := requestView{ + headerMap: map[string][]string{ + "X-Real-Ip": {"2606:4700::1"}, }, } - for _, tt := range tests { - t.Run(tt.headerName, func(t *testing.T) { - source := &singleHeaderSource{ - extractor: extractor, - headerName: tt.headerName, - headerKey: textproto.CanonicalMIMEHeaderKey(tt.headerName), - sourceName: HeaderSource(tt.headerName), - } + result, failure := ext.extract(req, SourceXRealIP) + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + wantIP := netip.MustParseAddr("2606:4700::1") + if result.IP != wantIP { + t.Errorf("IP = %v, want %v", result.IP, wantIP) + } +} - if got := source.Source().String(); got != tt.wantName { - t.Errorf("Source().String() = %q, want %q", got, tt.wantName) - } - }) +func TestSingleHeaderExtractor_NoHeadersAtAll(t *testing.T) { + ext := singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: "X-Custom", + }} + req := requestView{} + + _, failure := ext.extract(req, HeaderSource("X-Custom")) + if failure == nil { + t.Fatal("expected failure, got nil") + } + if failure != errSourceUnavailable { + t.Errorf("failure = %+v, want errSourceUnavailable", failure) } } diff --git a/test_helpers_test.go b/test_helpers_test.go index e2a7da2..871aee9 100644 --- a/test_helpers_test.go +++ b/test_helpers_test.go @@ -62,13 +62,6 @@ func asError(err error, target any) bool { } } -func errorContains(err, target error) bool { - if err == nil { - return false - } - return errors.Is(err, target) -} - func extractionStateOf(extraction Extraction) extractionState { state := extractionState{ HasIP: extraction.IP.IsValid(), @@ -89,10 +82,10 @@ func errorTextStateOf(err error, contains string) errorTextState { } } -func mustNewExtractor(t *testing.T, opts ...Option) *Extractor { +func mustNewExtractor(t *testing.T, cfg Config) *Extractor { t.Helper() - extractor, err := New(opts...) + extractor, err := New(cfg) if err != nil { t.Fatalf("New() error = %v", err) } @@ -100,6 +93,17 @@ func mustNewExtractor(t *testing.T, opts ...Option) *Extractor { return extractor } +func mustProxyPrefixesFromAddrs(t *testing.T, addrs ...netip.Addr) []netip.Prefix { + t.Helper() + + prefixes, err := ProxyPrefixesFromAddrs(addrs...) + if err != nil { + t.Fatalf("ProxyPrefixesFromAddrs() error = %v", err) + } + + return prefixes +} + func mustParseCIDRs(t *testing.T, cidrs ...string) []netip.Prefix { t.Helper() diff --git a/trust_benchmark_test.go b/trust_benchmark_test.go new file mode 100644 index 0000000..168c2f0 --- /dev/null +++ b/trust_benchmark_test.go @@ -0,0 +1,103 @@ +package clientip + +import ( + "net/netip" + "testing" +) + +func BenchmarkIsTrustedProxy(b *testing.B) { + cidrs := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/12"), + netip.MustParsePrefix("192.168.0.0/16"), + } + matcher := newPrefixMatcher(cidrs) + testIPs := []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("172.16.0.1"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("1.1.1.1"), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, ip := range testIPs { + isTrustedProxy(ip, matcher, cidrs) + } + } +} + +func BenchmarkIsTrustedProxyLargeCIDRSetPrecomputed(b *testing.B) { + const prefixCount = 4096 + prefixes := make([]netip.Prefix, 0, prefixCount) + for i := 0; i < prefixCount; i++ { + secondOctet := byte((i / 16) % 256) + thirdOctet := byte(i % 256) + prefixes = append(prefixes, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, secondOctet, thirdOctet, 0}), 24)) + } + + matcher := newPrefixMatcher(prefixes) + ip := netip.MustParseAddr("10.128.8.8") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if !isTrustedProxy(ip, matcher, prefixes) { + b.Fatal("expected trusted proxy") + } + } +} + +func BenchmarkIsTrustedProxyLargeCIDRSetLinearFallback(b *testing.B) { + const prefixCount = 4096 + prefixes := make([]netip.Prefix, 0, prefixCount) + for i := 0; i < prefixCount; i++ { + secondOctet := byte((i / 16) % 256) + thirdOctet := byte(i % 256) + prefixes = append(prefixes, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, secondOctet, thirdOctet, 0}), 24)) + } + + ip := netip.MustParseAddr("10.128.8.8") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if !isTrustedProxy(ip, prefixMatcher{}, prefixes) { + b.Fatal("expected trusted proxy") + } + } +} + +func BenchmarkChainAnalysisRightmost(b *testing.B) { + policy := proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + MinTrustedProxies: 1, + MaxTrustedProxies: 3, + } + policy.TrustedProxyMatch = newPrefixMatcher(policy.TrustedProxyCIDRs) + parts := []string{"1.1.1.1", "8.8.8.8", "10.0.0.1", "10.0.0.2"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := analyzeChainRightmost(parts, policy, true) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkChainAnalysisLeftmost(b *testing.B) { + policy := proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + MinTrustedProxies: 1, + MaxTrustedProxies: 3, + } + policy.TrustedProxyMatch = newPrefixMatcher(policy.TrustedProxyCIDRs) + parts := []string{"1.1.1.1", "8.8.8.8", "10.0.0.1", "10.0.0.2"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := analyzeChainLeftmost(parts, policy, true) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/trust_chain.go b/trust_chain.go new file mode 100644 index 0000000..11fb9af --- /dev/null +++ b/trust_chain.go @@ -0,0 +1,170 @@ +package clientip + +import ( + "net/netip" +) + +type proxyPolicy struct { + TrustedProxyCIDRs []netip.Prefix + TrustedProxyMatch prefixMatcher + MinTrustedProxies int + MaxTrustedProxies int +} + +type chainAnalysis struct { + ClientIndex int + TrustedCount int + TrustedIndices []int +} + +func isTrustedProxy(ip netip.Addr, matcher prefixMatcher, cidrs []netip.Prefix) bool { + if !ip.IsValid() { + return false + } + + if matcher.initialized { + return matcher.contains(ip) + } + + for _, cidr := range cidrs { + if cidr.Contains(ip) { + return true + } + } + + return false +} + +func validateProxyCountPolicy(trustedCount int, policy proxyPolicy) error { + if len(policy.TrustedProxyCIDRs) > 0 && policy.MinTrustedProxies > 0 && trustedCount == 0 { + return ErrNoTrustedProxies + } + + if policy.MinTrustedProxies > 0 && trustedCount < policy.MinTrustedProxies { + return ErrTooFewTrustedProxies + } + + if policy.MaxTrustedProxies > 0 && trustedCount > policy.MaxTrustedProxies { + return ErrTooManyTrustedProxies + } + + return nil +} + +func analyzeChainRightmost(parts []string, policy proxyPolicy, collectTrustedIndices bool) (chainAnalysis, netip.Addr, error) { + trustedCount := 0 + clientIndex := 0 + clientIP := netip.Addr{} + + var trustedIndices []int + if collectTrustedIndices { + trustedIndices = make([]int, 0, len(parts)) + } + + hasCIDRs := len(policy.TrustedProxyCIDRs) > 0 + + for i := len(parts) - 1; i >= 0; i-- { + if !hasCIDRs && policy.MaxTrustedProxies > 0 && trustedCount >= policy.MaxTrustedProxies { + clientIndex = i + clientIP = parseChainIP(parts[i]) + break + } + + ip := parseChainIP(parts[i]) + + if hasCIDRs && !isTrustedProxy(ip, policy.TrustedProxyMatch, policy.TrustedProxyCIDRs) { + clientIndex = i + clientIP = ip + break + } + + if collectTrustedIndices { + trustedIndices = append(trustedIndices, i) + } + trustedCount++ + clientIP = ip + } + + analysis := chainAnalysis{ + ClientIndex: clientIndex, + TrustedCount: trustedCount, + TrustedIndices: trustedIndices, + } + + if err := validateProxyCountPolicy(trustedCount, policy); err != nil { + return analysis, netip.Addr{}, err + } + + return analysis, clientIP, nil +} + +func analyzeChainLeftmost(parts []string, policy proxyPolicy, collectTrustedIndices bool) (chainAnalysis, netip.Addr, error) { + if len(policy.TrustedProxyCIDRs) == 0 { + analysis := chainAnalysis{ClientIndex: 0, TrustedCount: 0} + return analysis, parseIP(parts[0]), nil + } + + trustedCount := 0 + leftmostUntrustedIndex := -1 + leftmostUntrustedIP := netip.Addr{} + hasLeftmostUntrusted := false + + fallbackClientIndex := 0 + fallbackClientIP := netip.Addr{} + hasFallbackClient := false + + var trustedIndices []int + if collectTrustedIndices { + trustedIndices = make([]int, 0, len(parts)) + } + + stillTrailingTrusted := true + + for i := len(parts) - 1; i >= 0; i-- { + ip := parseChainIP(parts[i]) + trusted := isTrustedProxy(ip, policy.TrustedProxyMatch, policy.TrustedProxyCIDRs) + + if stillTrailingTrusted && trusted { + if collectTrustedIndices { + trustedIndices = append(trustedIndices, i) + } + trustedCount++ + continue + } + + if stillTrailingTrusted { + fallbackClientIndex = i + fallbackClientIP = ip + hasFallbackClient = true + } + + stillTrailingTrusted = false + if !trusted { + leftmostUntrustedIndex = i + leftmostUntrustedIP = ip + hasLeftmostUntrusted = true + } + } + + analysis := chainAnalysis{TrustedCount: trustedCount} + if collectTrustedIndices { + analysis.TrustedIndices = trustedIndices + } + + if err := validateProxyCountPolicy(trustedCount, policy); err != nil { + return analysis, netip.Addr{}, err + } + + if hasLeftmostUntrusted { + analysis.ClientIndex = leftmostUntrustedIndex + return analysis, leftmostUntrustedIP, nil + } + + if hasFallbackClient { + analysis.ClientIndex = fallbackClientIndex + return analysis, fallbackClientIP, nil + } + + analysis.ClientIndex = 0 + return analysis, parseChainIP(parts[analysis.ClientIndex]), nil +} diff --git a/trust_chain_test.go b/trust_chain_test.go new file mode 100644 index 0000000..57c3697 --- /dev/null +++ b/trust_chain_test.go @@ -0,0 +1,184 @@ +package clientip + +import ( + "errors" + "net/netip" + "testing" +) + +func TestValidateProxyCount(t *testing.T) { + tests := []struct { + name string + minProxies int + maxProxies int + trustedCIDRs []netip.Prefix + trustedCount int + wantErr error + }{ + {name: "within range", minProxies: 1, maxProxies: 3, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 2}, + {name: "at minimum", minProxies: 1, maxProxies: 3, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 1}, + {name: "at maximum", minProxies: 1, maxProxies: 3, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 3}, + {name: "no trusted proxies allowed when minimum is zero", minProxies: 0, maxProxies: 3, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 0}, + {name: "no trusted proxies with minimum requirement", minProxies: 1, maxProxies: 3, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 0, wantErr: ErrNoTrustedProxies}, + {name: "below minimum", minProxies: 2, maxProxies: 3, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 1, wantErr: ErrTooFewTrustedProxies}, + {name: "above maximum", minProxies: 1, maxProxies: 2, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 3, wantErr: ErrTooManyTrustedProxies}, + {name: "no minimum requirement", minProxies: 0, maxProxies: 3, trustedCIDRs: []netip.Prefix{}, trustedCount: 0}, + {name: "no maximum limit", minProxies: 1, maxProxies: 0, trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, trustedCount: 100}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateProxyCountPolicy(tt.trustedCount, proxyPolicy{ + TrustedProxyCIDRs: tt.trustedCIDRs, + MinTrustedProxies: tt.minProxies, + MaxTrustedProxies: tt.maxProxies, + }) + if tt.wantErr == nil { + if err != nil { + t.Fatalf("validateProxyCountPolicy() error = %v, want nil", err) + } + return + } + + if !errors.Is(err, tt.wantErr) { + t.Fatalf("validateProxyCountPolicy() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +func TestAnalyzeChainRightmostNoCIDRs(t *testing.T) { + tests := []struct { + name string + parts []string + maxTrustedProxies int + wantClientIndex int + wantTrustedCount int + }{ + {name: "no max proxies", parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, maxTrustedProxies: 0, wantClientIndex: 0, wantTrustedCount: 3}, + {name: "max 1 proxy", parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, maxTrustedProxies: 1, wantClientIndex: 1, wantTrustedCount: 1}, + {name: "max 2 proxies", parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, maxTrustedProxies: 2, wantClientIndex: 0, wantTrustedCount: 2}, + {name: "single IP", parts: []string{"1.1.1.1"}, maxTrustedProxies: 1, wantClientIndex: 0, wantTrustedCount: 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + analysis, _, err := analyzeChainRightmost(tt.parts, proxyPolicy{MaxTrustedProxies: tt.maxTrustedProxies}, true) + if err != nil { + t.Fatalf("analyzeChainRightmost() error = %v", err) + } + + if analysis.ClientIndex != tt.wantClientIndex { + t.Errorf("clientIndex = %d, want %d", analysis.ClientIndex, tt.wantClientIndex) + } + if analysis.TrustedCount != tt.wantTrustedCount { + t.Errorf("trustedCount = %d, want %d", analysis.TrustedCount, tt.wantTrustedCount) + } + }) + } +} + +func TestAnalyzeChainRightmostWithCIDRs(t *testing.T) { + policy := proxyPolicy{TrustedProxyCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}} + policy.TrustedProxyMatch = newPrefixMatcher(policy.TrustedProxyCIDRs) + + tests := []struct { + name string + parts []string + minProxies int + maxProxies int + wantClientIndex int + wantTrustedCount int + wantErr error + wantTrustedIndices []int + }{ + {name: "one trusted proxy at end", parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, maxProxies: 2, wantClientIndex: 1, wantTrustedCount: 1, wantTrustedIndices: []int{2}}, + {name: "two trusted proxies at end", parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2"}, maxProxies: 2, wantClientIndex: 0, wantTrustedCount: 2, wantTrustedIndices: []int{2, 1}}, + {name: "no trusted proxies allowed when minimum is zero", parts: []string{"1.1.1.1", "8.8.8.8"}, maxProxies: 2, wantClientIndex: 1, wantTrustedCount: 0, wantTrustedIndices: []int{}}, + {name: "no trusted proxies with minimum requirement", parts: []string{"1.1.1.1", "8.8.8.8"}, minProxies: 1, maxProxies: 2, wantClientIndex: 1, wantTrustedCount: 0, wantErr: ErrNoTrustedProxies, wantTrustedIndices: []int{}}, + {name: "too many trusted proxies", parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2", "10.0.0.3"}, maxProxies: 2, wantClientIndex: 0, wantTrustedCount: 3, wantErr: ErrTooManyTrustedProxies, wantTrustedIndices: []int{3, 2, 1}}, + {name: "below min proxies", parts: []string{"1.1.1.1", "10.0.0.1"}, minProxies: 2, maxProxies: 3, wantClientIndex: 0, wantTrustedCount: 1, wantErr: ErrTooFewTrustedProxies}, + {name: "mixed trusted and untrusted", parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1", "198.51.100.2", "10.0.0.2"}, maxProxies: 3, wantClientIndex: 3, wantTrustedCount: 1, wantTrustedIndices: []int{4}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + active := policy + active.MinTrustedProxies = tt.minProxies + active.MaxTrustedProxies = tt.maxProxies + + analysis, _, err := analyzeChainRightmost(tt.parts, active, true) + if tt.wantErr == nil { + if err != nil { + t.Fatalf("analyzeChainRightmost() error = %v, want nil", err) + } + } else if !errors.Is(err, tt.wantErr) { + t.Fatalf("analyzeChainRightmost() error = %v, want %v", err, tt.wantErr) + } + + if analysis.ClientIndex != tt.wantClientIndex { + t.Errorf("clientIndex = %d, want %d", analysis.ClientIndex, tt.wantClientIndex) + } + if analysis.TrustedCount != tt.wantTrustedCount { + t.Errorf("trustedCount = %d, want %d", analysis.TrustedCount, tt.wantTrustedCount) + } + if tt.wantErr == nil && tt.wantTrustedIndices != nil { + if len(analysis.TrustedIndices) != len(tt.wantTrustedIndices) { + t.Fatalf("trustedIndices length = %d, want %d", len(analysis.TrustedIndices), len(tt.wantTrustedIndices)) + } + for i, idx := range tt.wantTrustedIndices { + if analysis.TrustedIndices[i] != idx { + t.Fatalf("trustedIndices[%d] = %d, want %d", i, analysis.TrustedIndices[i], idx) + } + } + } + }) + } +} + +func TestAnalyzeChainLeftmostWithCIDRs(t *testing.T) { + policy := proxyPolicy{TrustedProxyCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}} + policy.TrustedProxyMatch = newPrefixMatcher(policy.TrustedProxyCIDRs) + + tests := []struct { + name string + parts []string + minProxies int + maxProxies int + wantClientIndex int + wantTrustedCount int + wantErr error + }{ + {name: "one trusted proxy at end", parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, maxProxies: 2, wantClientIndex: 0, wantTrustedCount: 1}, + {name: "two trusted proxies at end", parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2"}, maxProxies: 2, wantClientIndex: 0, wantTrustedCount: 2}, + {name: "all trusted proxies", parts: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}, maxProxies: 3, wantClientIndex: 0, wantTrustedCount: 3}, + {name: "no trusted proxies allowed when minimum is zero", parts: []string{"1.1.1.1", "8.8.8.8"}, maxProxies: 2, wantClientIndex: 0, wantTrustedCount: 0}, + {name: "no trusted proxies with minimum requirement", parts: []string{"1.1.1.1", "8.8.8.8"}, minProxies: 1, maxProxies: 2, wantClientIndex: 0, wantTrustedCount: 0, wantErr: ErrNoTrustedProxies}, + {name: "too many trusted proxies", parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2", "10.0.0.3"}, maxProxies: 2, wantClientIndex: 0, wantTrustedCount: 3, wantErr: ErrTooManyTrustedProxies}, + {name: "below min proxies", parts: []string{"1.1.1.1", "10.0.0.1"}, minProxies: 2, maxProxies: 3, wantClientIndex: 0, wantTrustedCount: 1, wantErr: ErrTooFewTrustedProxies}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + active := policy + active.MinTrustedProxies = tt.minProxies + active.MaxTrustedProxies = tt.maxProxies + + analysis, _, err := analyzeChainLeftmost(tt.parts, active, true) + if tt.wantErr == nil { + if err != nil { + t.Fatalf("analyzeChainLeftmost() error = %v, want nil", err) + } + } else if !errors.Is(err, tt.wantErr) { + t.Fatalf("analyzeChainLeftmost() error = %v, want %v", err, tt.wantErr) + } + + if analysis.ClientIndex != tt.wantClientIndex { + t.Errorf("clientIndex = %d, want %d", analysis.ClientIndex, tt.wantClientIndex) + } + if analysis.TrustedCount != tt.wantTrustedCount { + t.Errorf("trustedCount = %d, want %d", analysis.TrustedCount, tt.wantTrustedCount) + } + }) + } +} diff --git a/trust_client_ip.go b/trust_client_ip.go new file mode 100644 index 0000000..02bcd9f --- /dev/null +++ b/trust_client_ip.go @@ -0,0 +1,145 @@ +package clientip + +import "net/netip" + +type clientIPPolicy struct { + AllowPrivateIPs bool + AllowReservedClientPrefixes []netip.Prefix +} + +type clientIPDisposition int + +const ( + clientIPInvalid clientIPDisposition = iota + clientIPValid + clientIPReserved + clientIPPrivate +) + +var ( + reservedClientIPv4Prefixes = []netip.Prefix{ + mustParsePrefix("0.0.0.0/8"), + mustParsePrefix("100.64.0.0/10"), + mustParsePrefix("192.0.0.0/24"), + mustParsePrefix("192.0.2.0/24"), + mustParsePrefix("198.18.0.0/15"), + mustParsePrefix("198.51.100.0/24"), + mustParsePrefix("203.0.113.0/24"), + mustParsePrefix("240.0.0.0/4"), + } + + reservedClientIPv6Prefixes = []netip.Prefix{ + mustParsePrefix("64:ff9b::/96"), + mustParsePrefix("64:ff9b:1::/48"), + mustParsePrefix("100::/64"), + mustParsePrefix("2001:2::/48"), + mustParsePrefix("2001:db8::/32"), + mustParsePrefix("2001:20::/28"), + } +) + +// ipv4SpecialFirstOctet marks first octets that appear in any special IPv4 range +// (private, reserved, loopback, link-local, multicast). If the first octet is not +// marked, the address is guaranteed to be a valid public IPv4 — allowing us to +// skip all individual checks in evaluateClientIP. +var ipv4SpecialFirstOctet [256]bool + +func init() { + // Every IPv4 prefix that evaluateClientIP may treat as non-public. + // This must cover the same ranges as the checks in evaluateClientIP: + // IsLoopback, IsLinkLocalUnicast, IsMulticast, IsUnspecified, IsPrivate, + // plus all entries in reservedClientIPv4Prefixes. + specialRanges := append([]netip.Prefix{ + mustParsePrefix("0.0.0.0/8"), // IsUnspecified + mustParsePrefix("10.0.0.0/8"), // IsPrivate + mustParsePrefix("127.0.0.0/8"), // IsLoopback + mustParsePrefix("169.254.0.0/16"), // IsLinkLocalUnicast + mustParsePrefix("172.16.0.0/12"), // IsPrivate + mustParsePrefix("192.168.0.0/16"), // IsPrivate + mustParsePrefix("224.0.0.0/3"), // IsMulticast + future reserved (224.0.0.0–255.255.255.255) + }, reservedClientIPv4Prefixes...) + + for _, prefix := range specialRanges { + markIPv4SpecialOctets(prefix) + } +} + +// markIPv4SpecialOctets marks all first octets covered by prefix in the lookup table. +func markIPv4SpecialOctets(prefix netip.Prefix) { + first := prefix.Addr().As4()[0] + bits := prefix.Bits() + if bits >= 8 { + ipv4SpecialFirstOctet[first] = true + return + } + + // Prefix wider than /8 — covers multiple first octets. + count := 1 << (8 - bits) + for i := 0; i < count; i++ { + ipv4SpecialFirstOctet[int(first)+i] = true + } +} + +func evaluateClientIP(ip netip.Addr, policy clientIPPolicy) clientIPDisposition { + if !ip.IsValid() { + return clientIPInvalid + } + + // Fast path: IPv4 with first octet not in any special range is always + // a valid public address. This avoids 6+ sequential method calls for the + // common case. + if ip.Is4() && !ipv4SpecialFirstOctet[ip.As4()[0]] { + return clientIPValid + } + + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsMulticast() || ip.IsUnspecified() { + return clientIPInvalid + } + + if isReservedIP(ip) && !isAllowlistedReservedClientIP(ip, policy.AllowReservedClientPrefixes) { + return clientIPReserved + } + + if !policy.AllowPrivateIPs && ip.IsPrivate() { + return clientIPPrivate + } + + return clientIPValid +} + +func isReservedIP(ip netip.Addr) bool { + if !ip.IsValid() { + return false + } + + ip = normalizeIP(ip) + + prefixes := reservedClientIPv6Prefixes + if ip.Is4() { + prefixes = reservedClientIPv4Prefixes + } + + for _, prefix := range prefixes { + if prefix.Contains(ip) { + return true + } + } + + return false +} + +func isAllowlistedReservedClientIP(ip netip.Addr, allowlist []netip.Prefix) bool { + if len(allowlist) == 0 || !ip.IsValid() { + return false + } + + ip = normalizeIP(ip) + + for _, prefix := range allowlist { + if prefix.Contains(ip) { + return true + } + } + + return false +} diff --git a/trust_client_ip_test.go b/trust_client_ip_test.go new file mode 100644 index 0000000..2b0c186 --- /dev/null +++ b/trust_client_ip_test.go @@ -0,0 +1,141 @@ +package clientip + +import ( + "net/netip" + "testing" +) + +func TestEvaluateClientIP(t *testing.T) { + tests := []struct { + name string + ip string + allowPrivate bool + want clientIPDisposition + }{ + {name: "public IPv4", ip: "1.1.1.1", want: clientIPValid}, + {name: "public IPv6", ip: "2606:4700:4700::1", want: clientIPValid}, + {name: "loopback IPv4", ip: "127.0.0.1", want: clientIPInvalid}, + {name: "loopback IPv6", ip: "::1", want: clientIPInvalid}, + {name: "link-local IPv4", ip: "169.254.1.1", want: clientIPInvalid}, + {name: "link-local IPv6", ip: "fe80::1", want: clientIPInvalid}, + {name: "multicast IPv4", ip: "224.0.0.1", want: clientIPInvalid}, + {name: "multicast IPv6", ip: "ff02::1", want: clientIPInvalid}, + {name: "unspecified IPv4", ip: "0.0.0.0", want: clientIPInvalid}, + {name: "unspecified IPv6", ip: "::", want: clientIPInvalid}, + {name: "private IPv4 rejected", ip: "192.168.1.1", want: clientIPPrivate}, + {name: "private IPv4 allowed", ip: "192.168.1.1", allowPrivate: true, want: clientIPValid}, + {name: "10.x private rejected", ip: "10.0.0.1", want: clientIPPrivate}, + {name: "10.x private allowed", ip: "10.0.0.1", allowPrivate: true, want: clientIPValid}, + {name: "172.16.x private rejected", ip: "172.16.0.1", want: clientIPPrivate}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := evaluateClientIP(netip.MustParseAddr(tt.ip), clientIPPolicy{AllowPrivateIPs: tt.allowPrivate}) + if got != tt.want { + t.Errorf("evaluateClientIP(%s) = %v, want %v", tt.ip, got, tt.want) + } + }) + } + + if got := evaluateClientIP(netip.Addr{}, clientIPPolicy{}); got != clientIPInvalid { + t.Fatalf("evaluateClientIP(invalid) = %v, want %v", got, clientIPInvalid) + } +} + +func TestIsReservedIP(t *testing.T) { + tests := []struct { + name string + ip string + reserved bool + }{ + {name: "CGN start", ip: "100.64.0.0", reserved: true}, + {name: "CGN middle", ip: "100.100.100.100", reserved: true}, + {name: "CGN end", ip: "100.127.255.255", reserved: true}, + {name: "Not CGN - before", ip: "100.63.255.255", reserved: false}, + {name: "Not CGN - after", ip: "100.128.0.0", reserved: false}, + {name: "this-network reserved", ip: "0.1.2.3", reserved: true}, + {name: "IETF protocol assignments reserved", ip: "192.0.0.8", reserved: true}, + {name: "benchmarking reserved", ip: "198.18.0.1", reserved: true}, + {name: "TEST-NET-1", ip: "192.0.2.1", reserved: true}, + {name: "TEST-NET-2", ip: "198.51.100.1", reserved: true}, + {name: "TEST-NET-3", ip: "203.0.113.1", reserved: true}, + {name: "future-use IPv4 reserved", ip: "240.0.0.1", reserved: true}, + {name: "IPv6 doc prefix", ip: "2001:db8::1", reserved: true}, + {name: "IPv6 benchmarking prefix", ip: "2001:2::1", reserved: true}, + {name: "IPv6 ORCHIDv2 prefix", ip: "2001:20::1", reserved: true}, + {name: "IPv6 NAT64 well-known prefix", ip: "64:ff9b::808:808", reserved: true}, + {name: "IPv6 NAT64 local-use prefix", ip: "64:ff9b:1::1", reserved: true}, + {name: "IPv6 discard-only prefix", ip: "100::1", reserved: true}, + {name: "Not IPv6 doc - different prefix", ip: "2001:db9::1", reserved: false}, + {name: "Not ORCHIDv2 - outside prefix", ip: "2001:30::1", reserved: false}, + {name: "Public IPv4", ip: "8.8.8.8", reserved: false}, + {name: "Private IPv4", ip: "192.168.1.1", reserved: false}, + {name: "Public IPv6", ip: "2001:4860:4860::8888", reserved: false}, + {name: "IPv4-mapped reserved IPv6", ip: "::ffff:198.51.100.1", reserved: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isReservedIP(netip.MustParseAddr(tt.ip)) + if got != tt.reserved { + t.Errorf("isReservedIP(%s) = %v, want %v", tt.ip, got, tt.reserved) + } + }) + } +} + +func TestEvaluateClientIPReservedRanges(t *testing.T) { + tests := []struct { + name string + ip string + want clientIPDisposition + }{ + {name: "CGN rejected", ip: "100.64.0.1", want: clientIPReserved}, + {name: "benchmarking range rejected", ip: "198.18.1.1", want: clientIPReserved}, + {name: "future-use IPv4 rejected", ip: "240.0.0.2", want: clientIPReserved}, + {name: "TEST-NET-1 rejected", ip: "192.0.2.1", want: clientIPReserved}, + {name: "TEST-NET-2 rejected", ip: "198.51.100.1", want: clientIPReserved}, + {name: "TEST-NET-3 rejected", ip: "203.0.113.1", want: clientIPReserved}, + {name: "IPv6 doc rejected", ip: "2001:db8::1", want: clientIPReserved}, + {name: "IPv6 benchmarking rejected", ip: "2001:2::1", want: clientIPReserved}, + {name: "IPv6 NAT64 well-known rejected", ip: "64:ff9b::808:808", want: clientIPReserved}, + {name: "Private allowed when configured", ip: "192.168.1.1", want: clientIPValid}, + {name: "Public allowed", ip: "8.8.8.8", want: clientIPValid}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := evaluateClientIP(netip.MustParseAddr(tt.ip), clientIPPolicy{AllowPrivateIPs: true}) + if got != tt.want { + t.Errorf("evaluateClientIP(%s) = %v, want %v", tt.ip, got, tt.want) + } + }) + } +} + +func TestEvaluateClientIPWithAllowedReservedClientPrefixes(t *testing.T) { + policy := clientIPPolicy{AllowReservedClientPrefixes: []netip.Prefix{netip.MustParsePrefix("100.64.0.0/10"), netip.MustParsePrefix("2001:db8::/32")}} + + tests := []struct { + name string + ip string + want clientIPDisposition + }{ + {name: "allowlisted reserved IPv4", ip: "100.64.0.1", want: clientIPValid}, + {name: "non-allowlisted reserved IPv4", ip: "198.51.100.1", want: clientIPReserved}, + {name: "allowlisted reserved IPv6", ip: "2001:db8::1", want: clientIPValid}, + {name: "non-allowlisted reserved IPv6", ip: "64:ff9b::808:808", want: clientIPReserved}, + {name: "private remains rejected", ip: "192.168.1.1", want: clientIPPrivate}, + {name: "public remains allowed", ip: "8.8.8.8", want: clientIPValid}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := evaluateClientIP(netip.MustParseAddr(tt.ip), policy) + if got != tt.want { + t.Errorf("evaluateClientIP(%s) = %v, want %v", tt.ip, got, tt.want) + } + }) + } +} diff --git a/trusted_proxy_matcher.go b/trust_matcher.go similarity index 90% rename from trusted_proxy_matcher.go rename to trust_matcher.go index 1658fe7..1509340 100644 --- a/trusted_proxy_matcher.go +++ b/trust_matcher.go @@ -2,7 +2,7 @@ package clientip import "net/netip" -type trustedProxyMatcher struct { +type prefixMatcher struct { initialized bool ipv4Root *prefixTrieNode ipv6Root *prefixTrieNode @@ -13,8 +13,8 @@ type prefixTrieNode struct { terminal bool } -func buildTrustedProxyMatcher(prefixes []netip.Prefix) trustedProxyMatcher { - matcher := trustedProxyMatcher{} +func newPrefixMatcher(prefixes []netip.Prefix) prefixMatcher { + matcher := prefixMatcher{} if len(prefixes) == 0 { return matcher } @@ -56,27 +56,7 @@ func buildTrustedProxyMatcher(prefixes []netip.Prefix) trustedProxyMatcher { return matcher } -func insertPrefix(root *prefixTrieNode, addr []byte, bits int) { - node := root - if bits == 0 { - node.terminal = true - return - } - - for bitIndex := 0; bitIndex < bits; bitIndex++ { - bit := addrBit(addr, bitIndex) - child := node.children[bit] - if child == nil { - child = &prefixTrieNode{} - node.children[bit] = child - } - node = child - } - - node.terminal = true -} - -func (m trustedProxyMatcher) contains(ip netip.Addr) bool { +func (m prefixMatcher) contains(ip netip.Addr) bool { if !m.initialized || !ip.IsValid() { return false } @@ -98,6 +78,26 @@ func (m trustedProxyMatcher) contains(ip netip.Addr) bool { return trieContains(m.ipv6Root, bytes[:]) } +func insertPrefix(root *prefixTrieNode, addr []byte, bits int) { + node := root + if bits == 0 { + node.terminal = true + return + } + + for bitIndex := 0; bitIndex < bits; bitIndex++ { + bit := addrBit(addr, bitIndex) + child := node.children[bit] + if child == nil { + child = &prefixTrieNode{} + node.children[bit] = child + } + node = child + } + + node.terminal = true +} + func trieContains(root *prefixTrieNode, addr []byte) bool { node := root if node == nil { diff --git a/trusted_proxy_matcher_test.go b/trust_matcher_test.go similarity index 54% rename from trusted_proxy_matcher_test.go rename to trust_matcher_test.go index 231e734..1759c3d 100644 --- a/trusted_proxy_matcher_test.go +++ b/trust_matcher_test.go @@ -5,8 +5,8 @@ import ( "testing" ) -func TestTrustedProxyMatcher_Contains(t *testing.T) { - matcher := buildTrustedProxyMatcher([]netip.Prefix{ +func TestMatcherContains(t *testing.T) { + matcher := newPrefixMatcher([]netip.Prefix{ netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("2001:db8::/32"), }) @@ -31,8 +31,8 @@ func TestTrustedProxyMatcher_Contains(t *testing.T) { } } -func TestTrustedProxyMatcher_ZeroPrefix(t *testing.T) { - v4Matcher := buildTrustedProxyMatcher([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) +func TestMatcherZeroPrefix(t *testing.T) { + v4Matcher := newPrefixMatcher([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) if !v4Matcher.contains(netip.MustParseAddr("8.8.8.8")) { t.Fatal("expected IPv4 matcher to trust all IPv4 addresses") } @@ -40,7 +40,7 @@ func TestTrustedProxyMatcher_ZeroPrefix(t *testing.T) { t.Fatal("expected IPv4 matcher to reject IPv6 addresses") } - v6Matcher := buildTrustedProxyMatcher([]netip.Prefix{netip.MustParsePrefix("::/0")}) + v6Matcher := newPrefixMatcher([]netip.Prefix{netip.MustParsePrefix("::/0")}) if !v6Matcher.contains(netip.MustParseAddr("2001:4860:4860::8888")) { t.Fatal("expected IPv6 matcher to trust all IPv6 addresses") } @@ -49,39 +49,25 @@ func TestTrustedProxyMatcher_ZeroPrefix(t *testing.T) { } } -func TestIsTrustedProxy_UsesPrecomputedMatcher(t *testing.T) { - extractor, err := New(WithTrustedProxyPrefixes(netip.MustParsePrefix("10.0.0.0/8"))) - if err != nil { - t.Fatalf("New() error = %v", err) +func TestIsTrustedProxyUsesMatcher(t *testing.T) { + matcher := newPrefixMatcher([]netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}) + if !matcher.initialized { + t.Fatal("expected matcher to be initialized") } - if !extractor.config.trustedProxyMatch.initialized { - t.Fatal("expected precomputed trusted proxy matcher to be initialized") - } - - if !extractor.isTrustedProxy(netip.MustParseAddr("10.12.1.3")) { + if !isTrustedProxy(netip.MustParseAddr("10.12.1.3"), matcher, []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}) { t.Fatal("expected address to be trusted") } - if extractor.isTrustedProxy(netip.MustParseAddr("8.8.8.8")) { + if isTrustedProxy(netip.MustParseAddr("8.8.8.8"), matcher, []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}) { t.Fatal("expected address to be untrusted") } } -func TestIsTrustedProxy_LinearFallbackWhenMatcherMissing(t *testing.T) { - extractor := &Extractor{ - config: &config{ - trustedProxyCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - }, - } - - if extractor.config.trustedProxyMatch.initialized { - t.Fatal("expected matcher to be uninitialized for manual config") - } - - if !extractor.isTrustedProxy(netip.MustParseAddr("10.12.1.3")) { +func TestIsTrustedProxyLinearFallbackWhenMatcherMissing(t *testing.T) { + if !isTrustedProxy(netip.MustParseAddr("10.12.1.3"), prefixMatcher{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}) { t.Fatal("expected address to be trusted via linear fallback") } - if extractor.isTrustedProxy(netip.MustParseAddr("8.8.8.8")) { + if isTrustedProxy(netip.MustParseAddr("8.8.8.8"), prefixMatcher{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}) { t.Fatal("expected address to be untrusted via linear fallback") } } diff --git a/types.go b/types.go index a794f83..9f6c22e 100644 --- a/types.go +++ b/types.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/netip" - "strings" ) var ( @@ -187,7 +186,3 @@ func ParseCIDRs(cidrs ...string) ([]netip.Prefix, error) { } return prefixes, nil } - -func normalizeSourceName(headerName string) string { - return strings.ToLower(strings.ReplaceAll(headerName, "-", "_")) -} diff --git a/types_test.go b/types_test.go index b6563e8..11f4f40 100644 --- a/types_test.go +++ b/types_test.go @@ -193,6 +193,7 @@ func TestSource_StringAndCanonicalization(t *testing.T) { {name: "x-real-ip alias", got: HeaderSource("X_Real_IP"), want: SourceXRealIP, text: "x_real_ip"}, {name: "remote addr alias", got: HeaderSource("Remote-Addr"), want: SourceRemoteAddr, text: "remote_addr"}, {name: "custom header", got: HeaderSource("CF-Connecting-IP"), want: HeaderSource("cf-connecting-ip"), text: "cf_connecting_ip"}, + {name: "static fallback alias", got: HeaderSource("Static-Fallback"), want: SourceStaticFallback, text: "static_fallback"}, {name: "blank header invalid", got: HeaderSource(" "), want: Source{}, text: ""}, } @@ -209,6 +210,29 @@ func TestSource_StringAndCanonicalization(t *testing.T) { } } +func TestHeaderSource_String(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "X-Forwarded-For", want: "x_forwarded_for"}, + {input: "Forwarded", want: "forwarded"}, + {input: "X-Real-IP", want: "x_real_ip"}, + {input: "CF-Connecting-IP", want: "cf_connecting_ip"}, + {input: "UPPERCASE-HEADER", want: "uppercase_header"}, + {input: "already_underscored", want: "already_underscored"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := HeaderSource(tt.input).String() + if got != tt.want { + t.Errorf("HeaderSource(%q).String() = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + func TestSource_TextAndJSONRoundTrip(t *testing.T) { original := HeaderSource("CF-Connecting-IP") @@ -287,15 +311,15 @@ func TestSource_BuiltinsIgnoreExportedValueMutation(t *testing.T) { t.Fatalf("HeaderSource(Remote-Addr) = %q, want %q", got, want) } - extractor := mustNewExtractor(t) + extractor := mustNewExtractor(t, DefaultConfig()) if diff := cmp.Diff([]Source{builtinSource(sourceRemoteAddr)}, extractor.config.sourcePriority); diff != "" { t.Fatalf("default source priority mismatch (-want +got):\n%s", diff) } - forwardedExtractor := mustNewExtractor(t, - WithTrustedLoopbackProxy(), - WithSourcePriority(HeaderSource("Forwarded")), - ) + forwardedConfig := DefaultConfig() + forwardedConfig.TrustedProxyPrefixes = LoopbackProxyPrefixes() + forwardedConfig.Sources = []Source{HeaderSource("Forwarded")} + forwardedExtractor := mustNewExtractor(t, forwardedConfig) if diff := cmp.Diff([]Source{builtinSource(sourceForwarded)}, forwardedExtractor.config.sourcePriority); diff != "" { t.Fatalf("canonicalized source priority mismatch (-want +got):\n%s", diff) } diff --git a/xff_parse.go b/xff_parse.go deleted file mode 100644 index fd6675a..0000000 --- a/xff_parse.go +++ /dev/null @@ -1,38 +0,0 @@ -package clientip - -func (e *Extractor) parseXFFValues(values []string) ([]string, error) { - if len(values) == 0 { - return nil, nil - } - - maxChainLength := e.config.maxChainLength - parts := make([]string, 0, e.chainPartsCapacity(values)) - for _, v := range values { - start := 0 - for i := 0; i <= len(v); i++ { - if i != len(v) && v[i] != ',' { - continue - } - - part := trimHTTPWhitespace(v[start:i]) - if part != "" { - if len(parts) >= maxChainLength { - e.config.metrics.RecordSecurityEvent(securityEventChainTooLong) - return nil, &ChainTooLongError{ - ExtractionError: ExtractionError{ - Err: ErrChainTooLong, - Source: builtinSource(sourceXForwardedFor), - }, - ChainLength: len(parts) + 1, - MaxLength: maxChainLength, - } - } - - parts = append(parts, part) - } - - start = i + 1 - } - } - return parts, nil -} diff --git a/xff_test.go b/xff_test.go deleted file mode 100644 index f98a3b2..0000000 --- a/xff_test.go +++ /dev/null @@ -1,771 +0,0 @@ -package clientip - -import ( - "errors" - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestParseXFFValues(t *testing.T) { - tests := []struct { - name string - values []string - want []string - }{ - { - name: "single value", - values: []string{"1.1.1.1"}, - want: []string{"1.1.1.1"}, - }, - { - name: "single value with multiple IPs", - values: []string{"1.1.1.1, 8.8.8.8"}, - want: []string{"1.1.1.1", "8.8.8.8"}, - }, - { - name: "multiple values combined", - values: []string{"1.1.1.1", "8.8.8.8"}, - want: []string{"1.1.1.1", "8.8.8.8"}, - }, - { - name: "whitespace trimmed", - values: []string{" 1.1.1.1 , 8.8.8.8 "}, - want: []string{"1.1.1.1", "8.8.8.8"}, - }, - { - name: "empty strings ignored", - values: []string{"1.1.1.1, , 8.8.8.8"}, - want: []string{"1.1.1.1", "8.8.8.8"}, - }, - { - name: "empty list", - values: []string{}, - want: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := mustNewExtractor(t) - got, err := extractor.parseXFFValues(tt.values) - if err != nil { - t.Fatalf("parseXFFValues() error = %v, want nil", err) - } - - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Fatalf("parseXFFValues() mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestParseXFFValues_MaxChainLength(t *testing.T) { - extractor := mustNewExtractor(t, WithMaxChainLength(5)) - - values := []string{"1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5, 6.6.6.6, 7.7.7.7"} - _, err := extractor.parseXFFValues(values) - - if !errorContains(err, ErrChainTooLong) { - t.Fatalf("parseXFFValues() error = %v, want ErrChainTooLong", err) - } -} - -func TestParseXFFValues_PreservesWireOrderAcrossHeaderLines(t *testing.T) { - extractor := mustNewExtractor(t, WithMaxChainLength(10)) - - values := []string{ - "1.1.1.1, 8.8.8.8", - "9.9.9.9", - " 4.4.4.4 , 5.5.5.5 ", - } - - got, err := extractor.parseXFFValues(values) - if err != nil { - t.Fatalf("parseXFFValues() error = %v", err) - } - - want := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9", "4.4.4.4", "5.5.5.5"} - if diff := cmp.Diff(want, got); diff != "" { - t.Fatalf("parseXFFValues() mismatch (-want +got):\n%s", diff) - } -} - -func TestParseXFFValues_MaxChainLength_AcrossHeaderLines(t *testing.T) { - extractor := mustNewExtractor(t, WithMaxChainLength(3)) - - _, err := extractor.parseXFFValues([]string{ - "1.1.1.1, 8.8.8.8", - "9.9.9.9", - "4.4.4.4", - }) - - if !errors.Is(err, ErrChainTooLong) { - t.Fatalf("parseXFFValues() error = %v, want ErrChainTooLong", err) - } -} - -func TestIsTrustedProxy(t *testing.T) { - cidrs := mustParseCIDRs(t, "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16") - - extractor := &Extractor{ - config: &config{ - trustedProxyCIDRs: cidrs, - }, - } - - tests := []struct { - name string - ip string - want bool - }{ - { - name: "10.x.x.x trusted", - ip: "10.0.0.1", - want: true, - }, - { - name: "172.16.x.x trusted", - ip: "172.16.0.1", - want: true, - }, - { - name: "192.168.x.x trusted", - ip: "192.168.1.1", - want: true, - }, - { - name: "public IP not trusted", - ip: "1.1.1.1", - want: false, - }, - { - name: "invalid IP not trusted", - ip: "invalid", - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip := parseIP(tt.ip) - got := extractor.isTrustedProxy(ip) - if got != tt.want { - t.Errorf("isTrustedProxy(%s) = %v, want %v", tt.ip, got, tt.want) - } - }) - } -} - -func TestValidateProxyCount(t *testing.T) { - tests := []struct { - name string - minProxies int - maxProxies int - trustedCIDRs []netip.Prefix - trustedCount int - wantErr error - }{ - { - name: "within range", - minProxies: 1, - maxProxies: 3, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 2, - wantErr: nil, - }, - { - name: "at minimum", - minProxies: 1, - maxProxies: 3, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 1, - wantErr: nil, - }, - { - name: "at maximum", - minProxies: 1, - maxProxies: 3, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 3, - wantErr: nil, - }, - { - name: "no trusted proxies allowed when minimum is zero", - minProxies: 0, - maxProxies: 3, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 0, - wantErr: nil, - }, - { - name: "no trusted proxies with minimum requirement", - minProxies: 1, - maxProxies: 3, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 0, - wantErr: ErrNoTrustedProxies, - }, - { - name: "below minimum", - minProxies: 2, - maxProxies: 3, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 1, - wantErr: ErrTooFewTrustedProxies, - }, - { - name: "above maximum", - minProxies: 1, - maxProxies: 2, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 3, - wantErr: ErrTooManyTrustedProxies, - }, - { - name: "no minimum requirement", - minProxies: 0, - maxProxies: 3, - trustedCIDRs: []netip.Prefix{}, - trustedCount: 0, - wantErr: nil, - }, - { - name: "no maximum limit", - minProxies: 1, - maxProxies: 0, - trustedCIDRs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, - trustedCount: 100, - wantErr: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := &Extractor{ - config: &config{ - minTrustedProxies: tt.minProxies, - maxTrustedProxies: tt.maxProxies, - trustedProxyCIDRs: tt.trustedCIDRs, - metrics: noopMetrics{}, - }, - } - - err := extractor.validateProxyCount(tt.trustedCount) - if tt.wantErr == nil { - if err != nil { - t.Errorf("validateProxyCount() error = %v, want nil", err) - } - return - } - - if !errorContains(err, tt.wantErr) { - t.Errorf("validateProxyCount() error = %v, want %v", err, tt.wantErr) - } - }) - } -} - -func TestAnalyzeChainRightmost_NoCIDRs(t *testing.T) { - tests := []struct { - name string - parts []string - maxTrustedProxies int - wantClientIndex int - wantTrustedCount int - }{ - { - name: "no max proxies", - parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, - maxTrustedProxies: 0, - wantClientIndex: 0, - wantTrustedCount: 3, - }, - { - name: "max 1 proxy", - parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, - maxTrustedProxies: 1, - wantClientIndex: 1, - wantTrustedCount: 1, - }, - { - name: "max 2 proxies", - parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, - maxTrustedProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 2, - }, - { - name: "single IP", - parts: []string{"1.1.1.1"}, - maxTrustedProxies: 1, - wantClientIndex: 0, - wantTrustedCount: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := &Extractor{ - config: &config{ - trustedProxyCIDRs: []netip.Prefix{}, - maxTrustedProxies: tt.maxTrustedProxies, - metrics: noopMetrics{}, - }, - } - - analysis, err := extractor.analyzeChainRightmost(tt.parts) - if err != nil { - t.Fatalf("analyzeChainRightmost() error = %v", err) - } - - if analysis.clientIndex != tt.wantClientIndex { - t.Errorf("clientIndex = %d, want %d", analysis.clientIndex, tt.wantClientIndex) - } - - if analysis.trustedCount != tt.wantTrustedCount { - t.Errorf("trustedCount = %d, want %d", analysis.trustedCount, tt.wantTrustedCount) - } - }) - } -} - -func TestAnalyzeChainRightmost_WithCIDRs(t *testing.T) { - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - - tests := []struct { - name string - parts []string - minProxies int - maxProxies int - wantClientIndex int - wantTrustedCount int - wantErr error - wantTrustedIndices []int - }{ - { - name: "one trusted proxy at end", - parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 1, - wantTrustedCount: 1, - wantTrustedIndices: []int{2}, - }, - { - name: "two trusted proxies at end", - parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 2, - wantTrustedIndices: []int{2, 1}, - }, - { - name: "no trusted proxies allowed when minimum is zero", - parts: []string{"1.1.1.1", "8.8.8.8"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 1, - wantTrustedCount: 0, - wantTrustedIndices: []int{}, - }, - { - name: "no trusted proxies with minimum requirement", - parts: []string{"1.1.1.1", "8.8.8.8"}, - minProxies: 1, - maxProxies: 2, - wantClientIndex: 1, - wantTrustedCount: 0, - wantErr: ErrNoTrustedProxies, - wantTrustedIndices: []int{}, - }, - { - name: "too many trusted proxies", - parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2", "10.0.0.3"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 3, - wantErr: ErrTooManyTrustedProxies, - wantTrustedIndices: []int{3, 2, 1}, - }, - { - name: "below min proxies", - parts: []string{"1.1.1.1", "10.0.0.1"}, - minProxies: 2, - maxProxies: 3, - wantErr: ErrTooFewTrustedProxies, - wantClientIndex: 0, - wantTrustedCount: 1, - }, - { - name: "mixed trusted and untrusted", - parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1", "198.51.100.2", "10.0.0.2"}, - minProxies: 0, - maxProxies: 3, - wantClientIndex: 3, - wantTrustedCount: 1, - wantTrustedIndices: []int{4}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := &Extractor{ - config: &config{ - trustedProxyCIDRs: cidrs, - minTrustedProxies: tt.minProxies, - maxTrustedProxies: tt.maxProxies, - metrics: noopMetrics{}, - }, - } - - analysis, err := extractor.analyzeChainRightmost(tt.parts) - - if tt.wantErr == nil { - if err != nil { - t.Fatalf("analyzeChainRightmost() error = %v, want nil", err) - } - } else if !errorContains(err, tt.wantErr) { - t.Fatalf("analyzeChainRightmost() error = %v, want %v", err, tt.wantErr) - } - - if analysis.clientIndex != tt.wantClientIndex { - t.Errorf("clientIndex = %d, want %d", analysis.clientIndex, tt.wantClientIndex) - } - - if analysis.trustedCount != tt.wantTrustedCount { - t.Errorf("trustedCount = %d, want %d", analysis.trustedCount, tt.wantTrustedCount) - } - - if tt.wantErr == nil && tt.wantTrustedIndices != nil { - if len(analysis.trustedIndices) != len(tt.wantTrustedIndices) { - t.Errorf("trustedIndices length = %d, want %d", len(analysis.trustedIndices), len(tt.wantTrustedIndices)) - } else { - for i, idx := range tt.wantTrustedIndices { - if analysis.trustedIndices[i] != idx { - t.Errorf("trustedIndices[%d] = %d, want %d", i, analysis.trustedIndices[i], idx) - } - } - } - } - }) - } -} - -func TestAnalyzeChainLeftmost_WithCIDRs(t *testing.T) { - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - - tests := []struct { - name string - parts []string - minProxies int - maxProxies int - wantClientIndex int - wantTrustedCount int - wantErr error - }{ - { - name: "one trusted proxy at end", - parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 1, - }, - { - name: "two trusted proxies at end", - parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 2, - }, - { - name: "all trusted proxies", - parts: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}, - minProxies: 0, - maxProxies: 3, - wantClientIndex: 0, - wantTrustedCount: 3, - }, - { - name: "no trusted proxies allowed when minimum is zero", - parts: []string{"1.1.1.1", "8.8.8.8"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 0, - }, - { - name: "no trusted proxies with minimum requirement", - parts: []string{"1.1.1.1", "8.8.8.8"}, - minProxies: 1, - maxProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 0, - wantErr: ErrNoTrustedProxies, - }, - { - name: "too many trusted proxies", - parts: []string{"1.1.1.1", "10.0.0.1", "10.0.0.2", "10.0.0.3"}, - minProxies: 0, - maxProxies: 2, - wantClientIndex: 0, - wantTrustedCount: 3, - wantErr: ErrTooManyTrustedProxies, - }, - { - name: "below min proxies", - parts: []string{"1.1.1.1", "10.0.0.1"}, - minProxies: 2, - maxProxies: 3, - wantClientIndex: 0, - wantTrustedCount: 1, - wantErr: ErrTooFewTrustedProxies, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := &Extractor{ - config: &config{ - trustedProxyCIDRs: cidrs, - minTrustedProxies: tt.minProxies, - maxTrustedProxies: tt.maxProxies, - metrics: noopMetrics{}, - }, - } - - analysis, err := extractor.analyzeChainLeftmost(tt.parts) - - if tt.wantErr == nil { - if err != nil { - t.Fatalf("analyzeChainLeftmost() error = %v, want nil", err) - } - } else if !errorContains(err, tt.wantErr) { - t.Fatalf("analyzeChainLeftmost() error = %v, want %v", err, tt.wantErr) - } - - if analysis.clientIndex != tt.wantClientIndex { - t.Errorf("clientIndex = %d, want %d", analysis.clientIndex, tt.wantClientIndex) - } - - if analysis.trustedCount != tt.wantTrustedCount { - t.Errorf("trustedCount = %d, want %d", analysis.trustedCount, tt.wantTrustedCount) - } - }) - } -} - -func TestSelectLeftmostUntrustedIP(t *testing.T) { - cidrs := mustParseCIDRs(t, "10.0.0.0/8") - - tests := []struct { - name string - parts []string - trustedProxiesFromRight int - wantIndex int - }{ - { - name: "first IP untrusted", - parts: []string{"1.1.1.1", "8.8.8.8", "10.0.0.1"}, - trustedProxiesFromRight: 1, - wantIndex: 0, - }, - { - name: "second IP untrusted", - parts: []string{"10.0.0.1", "1.1.1.1", "10.0.0.2"}, - trustedProxiesFromRight: 1, - wantIndex: 1, - }, - { - name: "all IPs trusted", - parts: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}, - trustedProxiesFromRight: 3, - wantIndex: 0, - }, - { - name: "no trusted proxies", - parts: []string{"1.1.1.1", "8.8.8.8"}, - trustedProxiesFromRight: 0, - wantIndex: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := &Extractor{ - config: &config{ - trustedProxyCIDRs: cidrs, - }, - } - - got := extractor.selectLeftmostUntrustedIP(tt.parts, tt.trustedProxiesFromRight) - if got != tt.wantIndex { - t.Errorf("selectLeftmostUntrustedIP() = %d, want %d", got, tt.wantIndex) - } - }) - } -} - -func TestIsPlausibleClientIP(t *testing.T) { - tests := []struct { - name string - ip string - allowPrivate bool - wantPlausible bool - }{ - {name: "public IPv4", ip: "1.1.1.1", allowPrivate: false, wantPlausible: true}, - {name: "public IPv6", ip: "2606:4700:4700::1", allowPrivate: false, wantPlausible: true}, - {name: "loopback IPv4", ip: "127.0.0.1", allowPrivate: false, wantPlausible: false}, - {name: "loopback IPv6", ip: "::1", allowPrivate: false, wantPlausible: false}, - {name: "link-local IPv4", ip: "169.254.1.1", allowPrivate: false, wantPlausible: false}, - {name: "link-local IPv6", ip: "fe80::1", allowPrivate: false, wantPlausible: false}, - {name: "multicast IPv4", ip: "224.0.0.1", allowPrivate: false, wantPlausible: false}, - {name: "multicast IPv6", ip: "ff02::1", allowPrivate: false, wantPlausible: false}, - {name: "unspecified IPv4", ip: "0.0.0.0", allowPrivate: false, wantPlausible: false}, - {name: "unspecified IPv6", ip: "::", allowPrivate: false, wantPlausible: false}, - {name: "private IPv4 rejected", ip: "192.168.1.1", allowPrivate: false, wantPlausible: false}, - {name: "private IPv4 allowed", ip: "192.168.1.1", allowPrivate: true, wantPlausible: true}, - {name: "10.x private rejected", ip: "10.0.0.1", allowPrivate: false, wantPlausible: false}, - {name: "10.x private allowed", ip: "10.0.0.1", allowPrivate: true, wantPlausible: true}, - {name: "172.16.x private rejected", ip: "172.16.0.1", allowPrivate: false, wantPlausible: false}, - {name: "invalid IP", ip: "invalid", allowPrivate: false, wantPlausible: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - extractor := &Extractor{ - config: &config{ - allowPrivateIPs: tt.allowPrivate, - metrics: noopMetrics{}, - }, - } - - ip := parseIP(tt.ip) - got := extractor.isPlausibleClientIP(ip) - - if got != tt.wantPlausible { - t.Errorf("isPlausibleClientIP(%s) = %v, want %v", tt.ip, got, tt.wantPlausible) - } - }) - } -} - -func TestIsReservedIP(t *testing.T) { - tests := []struct { - name string - ip string - reserved bool - }{ - {name: "CGN start", ip: "100.64.0.0", reserved: true}, - {name: "CGN middle", ip: "100.100.100.100", reserved: true}, - {name: "CGN end", ip: "100.127.255.255", reserved: true}, - {name: "Not CGN - before", ip: "100.63.255.255", reserved: false}, - {name: "Not CGN - after", ip: "100.128.0.0", reserved: false}, - {name: "this-network reserved", ip: "0.1.2.3", reserved: true}, - {name: "IETF protocol assignments reserved", ip: "192.0.0.8", reserved: true}, - {name: "benchmarking reserved", ip: "198.18.0.1", reserved: true}, - {name: "TEST-NET-1", ip: "192.0.2.1", reserved: true}, - {name: "TEST-NET-2", ip: "198.51.100.1", reserved: true}, - {name: "TEST-NET-3", ip: "203.0.113.1", reserved: true}, - {name: "future-use IPv4 reserved", ip: "240.0.0.1", reserved: true}, - {name: "IPv6 doc prefix", ip: "2001:db8::1", reserved: true}, - {name: "IPv6 benchmarking prefix", ip: "2001:2::1", reserved: true}, - {name: "IPv6 ORCHIDv2 prefix", ip: "2001:20::1", reserved: true}, - {name: "IPv6 NAT64 well-known prefix", ip: "64:ff9b::808:808", reserved: true}, - {name: "IPv6 NAT64 local-use prefix", ip: "64:ff9b:1::1", reserved: true}, - {name: "IPv6 discard-only prefix", ip: "100::1", reserved: true}, - {name: "Not IPv6 doc - different prefix", ip: "2001:db9::1", reserved: false}, - {name: "Not ORCHIDv2 - outside prefix", ip: "2001:30::1", reserved: false}, - {name: "Public IPv4", ip: "8.8.8.8", reserved: false}, - {name: "Private IPv4", ip: "192.168.1.1", reserved: false}, - {name: "Public IPv6", ip: "2001:4860:4860::8888", reserved: false}, - {name: "IPv4-mapped reserved IPv6", ip: "::ffff:198.51.100.1", reserved: true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip := netip.MustParseAddr(tt.ip) - got := isReservedIP(ip) - if got != tt.reserved { - t.Errorf("isReservedIP(%s) = %v, want %v", tt.ip, got, tt.reserved) - } - }) - } -} - -func TestIsPlausibleClientIP_ReservedRanges(t *testing.T) { - extractor := &Extractor{ - config: &config{ - allowPrivateIPs: true, // Allow private but still reject reserved - metrics: noopMetrics{}, - }, - } - - tests := []struct { - name string - ip string - wantOk bool - }{ - {name: "CGN rejected", ip: "100.64.0.1", wantOk: false}, - {name: "benchmarking range rejected", ip: "198.18.1.1", wantOk: false}, - {name: "future-use IPv4 rejected", ip: "240.0.0.2", wantOk: false}, - {name: "TEST-NET-1 rejected", ip: "192.0.2.1", wantOk: false}, - {name: "TEST-NET-2 rejected", ip: "198.51.100.1", wantOk: false}, - {name: "TEST-NET-3 rejected", ip: "203.0.113.1", wantOk: false}, - {name: "IPv6 doc rejected", ip: "2001:db8::1", wantOk: false}, - {name: "IPv6 benchmarking rejected", ip: "2001:2::1", wantOk: false}, - {name: "IPv6 NAT64 well-known rejected", ip: "64:ff9b::808:808", wantOk: false}, - {name: "Private allowed when configured", ip: "192.168.1.1", wantOk: true}, - {name: "Public allowed", ip: "8.8.8.8", wantOk: true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip := netip.MustParseAddr(tt.ip) - got := extractor.isPlausibleClientIP(ip) - if got != tt.wantOk { - t.Errorf("isPlausibleClientIP(%s) = %v, want %v", tt.ip, got, tt.wantOk) - } - }) - } -} - -func TestIsPlausibleClientIP_WithAllowedReservedClientPrefixes(t *testing.T) { - extractor := &Extractor{ - config: &config{ - allowReservedClientPrefixes: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.0/10"), - netip.MustParsePrefix("2001:db8::/32"), - }, - metrics: noopMetrics{}, - }, - } - - tests := []struct { - name string - ip string - wantOk bool - }{ - {name: "allowlisted reserved IPv4", ip: "100.64.0.1", wantOk: true}, - {name: "non-allowlisted reserved IPv4", ip: "198.51.100.1", wantOk: false}, - {name: "allowlisted reserved IPv6", ip: "2001:db8::1", wantOk: true}, - {name: "non-allowlisted reserved IPv6", ip: "64:ff9b::808:808", wantOk: false}, - {name: "private remains rejected", ip: "192.168.1.1", wantOk: false}, - {name: "public remains allowed", ip: "8.8.8.8", wantOk: true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ip := netip.MustParseAddr(tt.ip) - got := extractor.isPlausibleClientIP(ip) - if got != tt.wantOk { - t.Errorf("isPlausibleClientIP(%s) = %v, want %v", tt.ip, got, tt.wantOk) - } - }) - } -} From 360d03b73eb9c1ec4d7685086904dbad98f19f9c Mon Sep 17 00:00:00 2001 From: Thomas de Jong Date: Tue, 21 Apr 2026 21:17:45 +0200 Subject: [PATCH 2/5] fix: restore XFF chain compatibility and harden Forwarded parsing --- parse_ip.go | 36 +++++++++++++----- parse_ip_test.go | 38 +++++++++++++++++++ source_chain_extract.go | 10 ++++- source_chain_extract_test.go | 71 +++++++++++++++++++++++++++++++++++- source_execution.go | 2 + trust_benchmark_test.go | 4 +- trust_chain.go | 14 +++---- trust_chain_test.go | 6 +-- 8 files changed, 155 insertions(+), 26 deletions(-) diff --git a/parse_ip.go b/parse_ip.go index 78498bd..a20697d 100644 --- a/parse_ip.go +++ b/parse_ip.go @@ -16,26 +16,42 @@ func normalizeIP(ip netip.Addr) netip.Addr { } // parseChainIP parses an IP from a chain value that has already been -// extracted and trimmed by a header parser (XFF, Forwarded). -// It handles plain IPs and the [ip]:port format from Forwarded headers, -// but skips the quote-stripping and fallback paths of parseIP. +// extracted and trimmed by a header parser. +// +// This is intentionally stricter than parseIP: it accepts bare IPs, +// bracketed IPs, and bracketed IPs with a numeric port suffix only. func parseChainIP(s string) netip.Addr { ip, err := netip.ParseAddr(s) if err == nil { return ip } - // Handle [ip]:port from Forwarded header values. - // Extract the content between [ and ] without importing net. - if len(s) > 2 && s[0] == '[' { - if end := strings.IndexByte(s, ']'); end > 1 { - ip, err = netip.ParseAddr(s[1:end]) - if err == nil { - return ip + if len(s) < 2 || s[0] != '[' { + return netip.Addr{} + } + + end := strings.IndexByte(s, ']') + if end <= 1 { + return netip.Addr{} + } + + rest := s[end+1:] + if len(rest) > 0 { + if rest[0] != ':' || len(rest) == 1 { + return netip.Addr{} + } + for i := 1; i < len(rest); i++ { + if rest[i] < '0' || rest[i] > '9' { + return netip.Addr{} } } } + ip, err = netip.ParseAddr(s[1:end]) + if err == nil { + return ip + } + return netip.Addr{} } diff --git a/parse_ip_test.go b/parse_ip_test.go index c774e60..7944a87 100644 --- a/parse_ip_test.go +++ b/parse_ip_test.go @@ -105,6 +105,44 @@ func Test_parseRemoteAddr(t *testing.T) { } } +func TestParseChainIP(t *testing.T) { + tests := []struct { + name string + input string + want netip.Addr + wantErr bool + }{ + {name: "bare ipv4", input: "203.0.113.1", want: netip.MustParseAddr("203.0.113.1")}, + {name: "bracketed ipv6", input: "[2001:db8::1]", want: netip.MustParseAddr("2001:db8::1")}, + {name: "bracketed ipv6 with port", input: "[2001:db8::1]:443", want: netip.MustParseAddr("2001:db8::1")}, + {name: "xff style host port rejected", input: "203.0.113.1:443", wantErr: true}, + {name: "quoted value rejected", input: `"203.0.113.1"`, wantErr: true}, + {name: "trailing junk rejected", input: "[2001:db8::1]junk", wantErr: true}, + {name: "non numeric port rejected", input: "[2001:db8::1]:https", wantErr: true}, + {name: "missing port digits rejected", input: "[2001:db8::1]:", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseChainIP(tt.input) + if tt.wantErr { + if got.IsValid() { + t.Errorf("parseChainIP(%q) = %v, want invalid", tt.input, got) + } + return + } + + if !got.IsValid() { + t.Errorf("parseChainIP(%q) = invalid, want %v", tt.input, tt.want) + return + } + if got != tt.want { + t.Errorf("parseChainIP(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + func TestNormalizeIP(t *testing.T) { tests := []struct { name string diff --git a/source_chain_extract.go b/source_chain_extract.go index c196569..7fb060e 100644 --- a/source_chain_extract.go +++ b/source_chain_extract.go @@ -8,6 +8,7 @@ import ( type chainPolicy struct { headerName string parseValues func([]string) ([]string, error) + parseClientIP func(string) netip.Addr clientIP clientIPPolicy trustedProxy proxyPolicy selection ChainSelection @@ -90,11 +91,16 @@ func (e chainExtractor) extract(req requestView, source Source) (Extraction, *ex } func (e chainExtractor) analyzeChain(parts []string) (chainAnalysis, netip.Addr, error) { + parseClientIP := e.policy.parseClientIP + if parseClientIP == nil { + parseClientIP = parseIP + } + if e.policy.selection == LeftmostUntrustedIP { - return analyzeChainLeftmost(parts, e.policy.trustedProxy, e.policy.collectDebugInfo) + return analyzeChainLeftmost(parts, e.policy.trustedProxy, e.policy.collectDebugInfo, parseClientIP) } - return analyzeChainRightmost(parts, e.policy.trustedProxy, e.policy.collectDebugInfo) + return analyzeChainRightmost(parts, e.policy.trustedProxy, e.policy.collectDebugInfo, parseClientIP) } func (e chainExtractor) chainSeparator() string { diff --git a/source_chain_extract_test.go b/source_chain_extract_test.go index 327d2af..3d533bb 100644 --- a/source_chain_extract_test.go +++ b/source_chain_extract_test.go @@ -76,8 +76,9 @@ func TestChainExtractor_SingleValidValue(t *testing.T) { func TestChainExtractor_ChainWithTrustedProxies(t *testing.T) { trustedCIDR := netip.MustParsePrefix("10.0.0.0/8") ext := chainExtractor{policy: chainPolicy{ - headerName: "X-Forwarded-For", - parseValues: simpleXFFParse, + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + parseClientIP: parseIP, trustedProxy: proxyPolicy{ TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), @@ -110,6 +111,72 @@ func TestChainExtractor_ChainWithTrustedProxies(t *testing.T) { } } +func TestChainExtractor_XFFAllowsHostPortAndQuotedValues(t *testing.T) { + trustedCIDR := netip.MustParsePrefix("10.0.0.0/8") + ext := chainExtractor{policy: chainPolicy{ + headerName: "X-Forwarded-For", + parseValues: simpleXFFParse, + parseClientIP: parseIP, + trustedProxy: proxyPolicy{ + TrustedProxyCIDRs: []netip.Prefix{trustedCIDR}, + TrustedProxyMatch: newPrefixMatcher([]netip.Prefix{trustedCIDR}), + }, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + remoteAddrValue: "10.0.0.3:8080", + headerMap: map[string][]string{ + "X-Forwarded-For": {`"8.8.8.8", 10.0.0.1:8443, 10.0.0.2:8443`}, + }, + } + + result, failure, err := ext.extract(req, SourceXForwardedFor) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure != nil { + t.Fatalf("unexpected failure: %+v", failure) + } + if result.IP != netip.MustParseAddr("8.8.8.8") { + t.Fatalf("IP = %v, want 8.8.8.8", result.IP) + } + if result.TrustedProxyCount != 2 { + t.Fatalf("TrustedProxyCount = %d, want 2", result.TrustedProxyCount) + } +} + +func TestChainExtractor_ForwardedRejectsBracketedTrailingJunk(t *testing.T) { + ext := chainExtractor{policy: chainPolicy{ + headerName: "Forwarded", + parseClientIP: parseChainIP, + parseValues: func(values []string) ([]string, error) { + return parseForwardedValues(values, 8) + }, + selection: RightmostUntrustedIP, + }} + + req := requestView{ + headerMap: map[string][]string{ + "Forwarded": {"for=[2001:db8::1]junk"}, + }, + } + + _, failure, err := ext.extract(req, SourceForwarded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if failure == nil { + t.Fatal("expected invalid client IP failure") + } + if failure.kind != failureInvalidClientIP { + t.Fatalf("failure.kind = %v, want failureInvalidClientIP", failure.kind) + } + if failure.extractedIP != "[2001:db8::1]junk" { + t.Fatalf("failure.extractedIP = %q, want %q", failure.extractedIP, "[2001:db8::1]junk") + } +} + func TestChainExtractor_EmptyChainAfterParse(t *testing.T) { ext := chainExtractor{policy: chainPolicy{ headerName: "X-Forwarded-For", diff --git a/source_execution.go b/source_execution.go index dad1636..d1b5020 100644 --- a/source_execution.go +++ b/source_execution.go @@ -64,6 +64,7 @@ func (e *Extractor) compileExecutor(spec sourceSpec, configuredSource Source) so } return parts, nil }, + parseClientIP: parseChainIP, clientIP: e.clientIP, trustedProxy: e.proxy, selection: e.config.chainSelection, @@ -105,6 +106,7 @@ func (e *Extractor) compileExecutor(spec sourceSpec, configuredSource Source) so } return parts, nil }, + parseClientIP: parseIP, clientIP: e.clientIP, trustedProxy: e.proxy, selection: e.config.chainSelection, diff --git a/trust_benchmark_test.go b/trust_benchmark_test.go index 168c2f0..f6b8a02 100644 --- a/trust_benchmark_test.go +++ b/trust_benchmark_test.go @@ -77,7 +77,7 @@ func BenchmarkChainAnalysisRightmost(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err := analyzeChainRightmost(parts, policy, true) + _, _, err := analyzeChainRightmost(parts, policy, true, parseIP) if err != nil { b.Fatal(err) } @@ -95,7 +95,7 @@ func BenchmarkChainAnalysisLeftmost(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err := analyzeChainLeftmost(parts, policy, true) + _, _, err := analyzeChainLeftmost(parts, policy, true, parseIP) if err != nil { b.Fatal(err) } diff --git a/trust_chain.go b/trust_chain.go index 11fb9af..3dc6ea2 100644 --- a/trust_chain.go +++ b/trust_chain.go @@ -51,7 +51,7 @@ func validateProxyCountPolicy(trustedCount int, policy proxyPolicy) error { return nil } -func analyzeChainRightmost(parts []string, policy proxyPolicy, collectTrustedIndices bool) (chainAnalysis, netip.Addr, error) { +func analyzeChainRightmost(parts []string, policy proxyPolicy, collectTrustedIndices bool, parseClientIP func(string) netip.Addr) (chainAnalysis, netip.Addr, error) { trustedCount := 0 clientIndex := 0 clientIP := netip.Addr{} @@ -66,11 +66,11 @@ func analyzeChainRightmost(parts []string, policy proxyPolicy, collectTrustedInd for i := len(parts) - 1; i >= 0; i-- { if !hasCIDRs && policy.MaxTrustedProxies > 0 && trustedCount >= policy.MaxTrustedProxies { clientIndex = i - clientIP = parseChainIP(parts[i]) + clientIP = parseClientIP(parts[i]) break } - ip := parseChainIP(parts[i]) + ip := parseClientIP(parts[i]) if hasCIDRs && !isTrustedProxy(ip, policy.TrustedProxyMatch, policy.TrustedProxyCIDRs) { clientIndex = i @@ -98,10 +98,10 @@ func analyzeChainRightmost(parts []string, policy proxyPolicy, collectTrustedInd return analysis, clientIP, nil } -func analyzeChainLeftmost(parts []string, policy proxyPolicy, collectTrustedIndices bool) (chainAnalysis, netip.Addr, error) { +func analyzeChainLeftmost(parts []string, policy proxyPolicy, collectTrustedIndices bool, parseClientIP func(string) netip.Addr) (chainAnalysis, netip.Addr, error) { if len(policy.TrustedProxyCIDRs) == 0 { analysis := chainAnalysis{ClientIndex: 0, TrustedCount: 0} - return analysis, parseIP(parts[0]), nil + return analysis, parseClientIP(parts[0]), nil } trustedCount := 0 @@ -121,7 +121,7 @@ func analyzeChainLeftmost(parts []string, policy proxyPolicy, collectTrustedIndi stillTrailingTrusted := true for i := len(parts) - 1; i >= 0; i-- { - ip := parseChainIP(parts[i]) + ip := parseClientIP(parts[i]) trusted := isTrustedProxy(ip, policy.TrustedProxyMatch, policy.TrustedProxyCIDRs) if stillTrailingTrusted && trusted { @@ -166,5 +166,5 @@ func analyzeChainLeftmost(parts []string, policy proxyPolicy, collectTrustedIndi } analysis.ClientIndex = 0 - return analysis, parseChainIP(parts[analysis.ClientIndex]), nil + return analysis, parseClientIP(parts[analysis.ClientIndex]), nil } diff --git a/trust_chain_test.go b/trust_chain_test.go index 57c3697..c950eee 100644 --- a/trust_chain_test.go +++ b/trust_chain_test.go @@ -63,7 +63,7 @@ func TestAnalyzeChainRightmostNoCIDRs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - analysis, _, err := analyzeChainRightmost(tt.parts, proxyPolicy{MaxTrustedProxies: tt.maxTrustedProxies}, true) + analysis, _, err := analyzeChainRightmost(tt.parts, proxyPolicy{MaxTrustedProxies: tt.maxTrustedProxies}, true, parseIP) if err != nil { t.Fatalf("analyzeChainRightmost() error = %v", err) } @@ -107,7 +107,7 @@ func TestAnalyzeChainRightmostWithCIDRs(t *testing.T) { active.MinTrustedProxies = tt.minProxies active.MaxTrustedProxies = tt.maxProxies - analysis, _, err := analyzeChainRightmost(tt.parts, active, true) + analysis, _, err := analyzeChainRightmost(tt.parts, active, true, parseIP) if tt.wantErr == nil { if err != nil { t.Fatalf("analyzeChainRightmost() error = %v, want nil", err) @@ -164,7 +164,7 @@ func TestAnalyzeChainLeftmostWithCIDRs(t *testing.T) { active.MinTrustedProxies = tt.minProxies active.MaxTrustedProxies = tt.maxProxies - analysis, _, err := analyzeChainLeftmost(tt.parts, active, true) + analysis, _, err := analyzeChainLeftmost(tt.parts, active, true, parseIP) if tt.wantErr == nil { if err != nil { t.Fatalf("analyzeChainLeftmost() error = %v, want nil", err) From ba8377f4b79b8a0c8f60dcd32fd2d576e181c77b Mon Sep 17 00:00:00 2001 From: Thomas de Jong Date: Wed, 22 Apr 2026 06:53:07 +0200 Subject: [PATCH 3/5] fix: fail closed on malformed empty Forwarded segments --- CHANGELOG.md | 1 + README.md | 2 ++ extractor_test.go | 42 ++++++++++++++++++--------- observability_test.go | 63 ++++++++++++++++++++++++++++------------- parse_forwarded.go | 15 +++++----- parse_forwarded_test.go | 11 +++++++ parse_fuzz_test.go | 2 +- 7 files changed, 94 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 50db953..5e7e716 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on Keep a Changelog and this project follows Semantic Versio - **BREAKING:** `SourceStaticFallback` remains public but is resolver-result-only; it cannot be used in `Config.Sources`. - Presets remain `Config` helpers and now document resolver-oriented usage more clearly. - Prometheus integration is constructor-based: build metrics with `prometheus.New()` or `prometheus.NewWithRegisterer(...)` and assign them through `Config.Metrics`. +- `X-Forwarded-For` chain extraction again accepts the host:port and quoted forms already supported by `parseIP`, while `Forwarded` stays strict and now rejects present-but-empty values plus empty delimiter-created elements/parameters as malformed. - Internal orchestration now sits behind `internal/engine` and concrete source execution behind `internal/source`. ### Removed diff --git a/README.md b/README.md index 2c8e4e9..dcea79c 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,8 @@ case clientip.ResultUnknown: `ResultUnknown` covers non-nil errors outside the package's standard extraction and resolution categories. +`ErrInvalidForwardedHeader` covers malformed RFC7239 syntax, including present-but-empty `Forwarded` values and empty elements or parameters introduced by stray delimiters. In strict extraction, malformed `Forwarded` remains terminal and does not fall through to a lower-priority source. + Typed chain-related errors expose additional context: - `ProxyValidationError`: `Chain`, `TrustedProxyCount`, `MinTrustedProxies`, `MaxTrustedProxies` diff --git a/extractor_test.go b/extractor_test.go index 535a0eb..ad813f5 100644 --- a/extractor_test.go +++ b/extractor_test.go @@ -522,22 +522,36 @@ func TestExtract_StrictMode_MalformedForwarded_IsTerminal(t *testing.T) { cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} extractor := mustNewExtractor(t, cfg) - req := &http.Request{ - RemoteAddr: "1.1.1.1:8080", - Header: make(http.Header), + tests := []struct { + name string + forwarded string + }{ + {name: "unterminated quoted value", forwarded: "for=\"1.1.1.1"}, + {name: "empty header value", forwarded: ""}, + {name: "empty element between commas", forwarded: "for=1.1.1.1,, for=8.8.8.8"}, + {name: "empty parameter between semicolons", forwarded: "for=1.1.1.1;;proto=https"}, } - req.Header.Set("Forwarded", "for=\"1.1.1.1") - req.Header.Set("X-Forwarded-For", "8.8.8.8") - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail closed on malformed Forwarded") - } - if !errors.Is(err, ErrInvalidForwardedHeader) { - t.Fatalf("error = %v, want ErrInvalidForwardedHeader", err) - } - if result.Source != SourceForwarded { - t.Fatalf("source = %q, want %q", result.Source, SourceForwarded) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{ + RemoteAddr: "1.1.1.1:8080", + Header: make(http.Header), + } + req.Header.Set("Forwarded", tt.forwarded) + req.Header.Set("X-Forwarded-For", "8.8.8.8") + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail closed on malformed Forwarded") + } + if !errors.Is(err, ErrInvalidForwardedHeader) { + t.Fatalf("error = %v, want ErrInvalidForwardedHeader", err) + } + if result.Source != SourceForwarded { + t.Fatalf("source = %q, want %q", result.Source, SourceForwarded) + } + }) } } diff --git a/observability_test.go b/observability_test.go index 0cee6a9..47a4339 100644 --- a/observability_test.go +++ b/observability_test.go @@ -39,6 +39,13 @@ func (l *capturedLogger) snapshot() []capturedLogEntry { return entries } +func (l *capturedLogger) clear() { + l.mu.Lock() + defer l.mu.Unlock() + + l.entries = nil +} + func attrsToMap(args []any) map[string]any { attrs := make(map[string]any) for i := 0; i+1 < len(args); i += 2 { @@ -364,27 +371,43 @@ func TestLogging_MalformedForwarded_EmitsWarning(t *testing.T) { cfg.Sources = []Source{SourceForwarded, SourceRemoteAddr} extractor := mustNewExtractor(t, cfg) - req := newTestRequest("1.1.1.1:8080", "/test/malformed-forwarded") - req.Header.Set("Forwarded", "for=\"1.1.1.1") - - result, err := extractor.Extract(req) - if err == nil && result.IP.IsValid() { - t.Fatal("expected extraction to fail closed on malformed Forwarded") - } - if !errors.Is(err, ErrInvalidForwardedHeader) { - t.Fatalf("error = %v, want ErrInvalidForwardedHeader", err) - } - if result.Source != SourceForwarded { - t.Fatalf("source = %q, want %q", result.Source, SourceForwarded) - } - - entries := logger.snapshot() - if len(entries) != 1 { - t.Fatalf("logged entries = %d, want 1", len(entries)) + tests := []struct { + name string + forwarded string + path string + }{ + {name: "unterminated quoted value", forwarded: "for=\"1.1.1.1", path: "/test/malformed-forwarded/unterminated"}, + {name: "empty header value", forwarded: "", path: "/test/malformed-forwarded/empty"}, + {name: "empty element between commas", forwarded: "for=1.1.1.1,, for=8.8.8.8", path: "/test/malformed-forwarded/empty-element"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger.clear() + + req := newTestRequest("1.1.1.1:8080", tt.path) + req.Header.Set("Forwarded", tt.forwarded) + + result, err := extractor.Extract(req) + if err == nil && result.IP.IsValid() { + t.Fatal("expected extraction to fail closed on malformed Forwarded") + } + if !errors.Is(err, ErrInvalidForwardedHeader) { + t.Fatalf("error = %v, want ErrInvalidForwardedHeader", err) + } + if result.Source != SourceForwarded { + t.Fatalf("source = %q, want %q", result.Source, SourceForwarded) + } + + entries := logger.snapshot() + if len(entries) != 1 { + t.Fatalf("logged entries = %d, want 1", len(entries)) + } + + entry := entries[0] + assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventMalformedForwarded, SourceForwarded, tt.path, "1.1.1.1:8080") + }) } - - entry := entries[0] - assertCommonSecurityWarningAttrs(t, entry.attrs, SecurityEventMalformedForwarded, SourceForwarded, "/test/malformed-forwarded", "1.1.1.1:8080") } func TestExtractInput_UsesInputContextAndPathInLogs(t *testing.T) { diff --git a/parse_forwarded.go b/parse_forwarded.go index 050ceb1..9829634 100644 --- a/parse_forwarded.go +++ b/parse_forwarded.go @@ -14,7 +14,7 @@ func parseForwardedValues(values []string, maxChainLength int) ([]string, error) parts := make([]string, 0, chainPartsCapacity(values, maxChainLength)) for _, value := range values { - err := scanForwardedSegments(value, ',', func(element string) error { + err := scanForwardedSegments(value, ',', "element", func(element string) error { forwardedFor, hasFor, parseErr := parseForwardedElement(element) if parseErr != nil { return parseErr @@ -47,7 +47,7 @@ func parseForwardedValues(values []string, maxChainLength int) ([]string, error) } func parseForwardedElement(element string) (forwardedFor string, hasFor bool, err error) { - err = scanForwardedSegments(element, ';', func(param string) error { + err = scanForwardedSegments(element, ';', "parameter", func(param string) error { eq := strings.IndexByte(param, '=') if eq <= 0 { return fmt.Errorf("invalid forwarded parameter %q", param) @@ -86,7 +86,7 @@ func parseForwardedElement(element string) (forwardedFor string, hasFor bool, er return forwardedFor, hasFor, nil } -func scanForwardedSegments(value string, delimiter byte, onSegment func(string) error) error { +func scanForwardedSegments(value string, delimiter byte, segmentKind string, onSegment func(string) error) error { start := 0 inQuotes := false escaped := false @@ -123,10 +123,11 @@ func scanForwardedSegments(value string, delimiter byte, onSegment func(string) } segment := strings.TrimSpace(value[start:i]) - if segment != "" { - if err := onSegment(segment); err != nil { - return err - } + if segment == "" { + return fmt.Errorf("empty forwarded %s in %q", segmentKind, value) + } + if err := onSegment(segment); err != nil { + return err } start = i + 1 diff --git a/parse_forwarded_test.go b/parse_forwarded_test.go index 0e9a290..5f0acff 100644 --- a/parse_forwarded_test.go +++ b/parse_forwarded_test.go @@ -24,6 +24,13 @@ func TestParseForwardedValues(t *testing.T) { {name: "quoted semicolon is not treated as param delimiter", values: []string{"for=\"1.1.1.1;edge\";proto=https"}, want: []string{"1.1.1.1;edge"}}, {name: "escaped quote remains inside quoted value", values: []string{`for="1.1.1.1\";edge";proto=https`}, want: []string{`1.1.1.1";edge`}}, {name: "ignores element without for parameter", values: []string{"proto=https;by=10.0.0.1, for=8.8.8.8"}, want: []string{"8.8.8.8"}}, + {name: "empty header value", values: []string{""}, wantErr: true}, + {name: "whitespace only header value", values: []string{" "}, wantErr: true}, + {name: "leading comma creates empty element", values: []string{", for=1.1.1.1"}, wantErr: true}, + {name: "trailing comma creates empty element", values: []string{"for=1.1.1.1,"}, wantErr: true}, + {name: "double comma creates empty element", values: []string{"for=1.1.1.1,, for=8.8.8.8"}, wantErr: true}, + {name: "whitespace only empty element", values: []string{"for=1.1.1.1, , for=8.8.8.8"}, wantErr: true}, + {name: "empty line among multiple header lines is malformed", values: []string{"for=1.1.1.1", " "}, wantErr: true}, {name: "invalid parameter format", values: []string{"for"}, wantErr: true}, {name: "unterminated quoted string", values: []string{"for=\"1.1.1.1"}, wantErr: true}, {name: "duplicate for parameter", values: []string{"for=1.1.1.1;for=8.8.8.8"}, wantErr: true}, @@ -67,6 +74,10 @@ func TestParseForwardedValues_MalformedParameterMatrix(t *testing.T) { values []string }{ {name: "empty parameter key", values: []string{"=1.1.1.1"}}, + {name: "leading semicolon creates empty parameter", values: []string{";for=1.1.1.1"}}, + {name: "trailing semicolon creates empty parameter", values: []string{"for=1.1.1.1;"}}, + {name: "double semicolon creates empty parameter", values: []string{"for=1.1.1.1;;proto=https"}}, + {name: "whitespace only empty parameter", values: []string{"for=1.1.1.1; ;proto=https"}}, {name: "empty for value", values: []string{"for="}}, {name: "empty quoted for value", values: []string{`for=""`}}, {name: "invalid quoted for value suffix", values: []string{`for="1.1.1.1"extra`}}, diff --git a/parse_fuzz_test.go b/parse_fuzz_test.go index 7dd8129..dc79b64 100644 --- a/parse_fuzz_test.go +++ b/parse_fuzz_test.go @@ -85,7 +85,7 @@ func FuzzParseXFFValues_ErrorShapeAndOutput(f *testing.F) { } func FuzzParseForwardedValues_ErrorShapeAndOutput(f *testing.F) { - for _, seed := range []string{"for=1.1.1.1", "for=1.1.1.1, for=8.8.8.8", "for=1.1.1.1;proto=https", `for="[2606:4700:4700::1]:443"`, `for="1.1.1.1\"edge"`, "for", `for="unterminated`, ""} { + for _, seed := range []string{"for=1.1.1.1", "for=1.1.1.1, for=8.8.8.8", "for=1.1.1.1;proto=https", `for="[2606:4700:4700::1]:443"`, `for="1.1.1.1\"edge"`, "for", `for="unterminated`, "", " ", ",for=1.1.1.1", "for=1.1.1.1,", "for=1.1.1.1,,for=8.8.8.8", ";for=1.1.1.1", "for=1.1.1.1;", "for=1.1.1.1;;proto=https"} { f.Add(seed) } From a672bacb814545a420f365af6a51f61b58c6cbf8 Mon Sep 17 00:00:00 2001 From: Thomas de Jong Date: Thu, 23 Apr 2026 20:31:26 +0200 Subject: [PATCH 4/5] refactor: simplify source extraction orchestration --- extractor.go | 183 ++++++++++++++++++++---- extractor_orchestration_test.go | 55 +++++++ resolver_test.go | 48 +++---- source_build_test.go | 69 --------- source_chained.go | 56 -------- source_chained_test.go | 246 -------------------------------- source_compile.go | 45 ------ source_execution.go | 188 +++++------------------- 8 files changed, 264 insertions(+), 626 deletions(-) create mode 100644 extractor_orchestration_test.go delete mode 100644 source_build_test.go delete mode 100644 source_chained.go delete mode 100644 source_chained_test.go delete mode 100644 source_compile.go diff --git a/extractor.go b/extractor.go index 5ca0fb8..ce04ff1 100644 --- a/extractor.go +++ b/extractor.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/netip" + "net/textproto" ) // Extractor resolves client IP information from HTTP requests and @@ -12,10 +13,20 @@ import ( // // Extractor instances are safe for concurrent reuse. type Extractor struct { - config *config - source sourceExtractor - clientIP clientIPPolicy - proxy proxyPolicy + config *config + sources []configuredSource + clientIP clientIPPolicy + proxy proxyPolicy + extractViewFunc func(requestView) (Extraction, error) +} + +type configuredSource struct { + source Source + name string + unavailableErr *ExtractionError + chain chainExtractor + single singleHeaderExtractor + remote remoteAddrExtractor } // New creates an Extractor from a Config. @@ -38,7 +49,7 @@ func New(public Config) (*Extractor, error) { MaxTrustedProxies: cfg.maxTrustedProxies, }, } - extractor.source = extractor.buildSourceChain(cfg) + extractor.sources = extractor.buildConfiguredSources(cfg.sourcePriority) return extractor, nil } @@ -56,7 +67,7 @@ func (e *Extractor) Extract(r *http.Request) (Extraction, error) { return e.extractFromRemoteAddr(r.RemoteAddr) } - return e.extractWithSource(e.source, requestViewFromRequest(r)) + return e.extractRequestView(requestViewFromRequest(r)) } // ExtractAddr resolves only the client IP address. @@ -81,7 +92,7 @@ func (e *Extractor) ExtractInput(input Input) (Extraction, error) { return e.extractFromRemoteAddr(input.RemoteAddr) } - return e.extractWithSource(e.source, requestViewFromInput(input)) + return e.extractRequestView(requestViewFromInput(input)) } // ExtractInputAddr resolves only the client IP address from framework-agnostic @@ -95,22 +106,80 @@ func (e *Extractor) ExtractInputAddr(input Input) (netip.Addr, error) { return extraction.IP, nil } -func (e *Extractor) extractWithSource(source sourceExtractor, r requestView) (Extraction, error) { +func (e *Extractor) extractRequestView(r requestView) (Extraction, error) { if err := r.context().Err(); err != nil { return Extraction{}, err } - result, err := source.extract(r) - if err != nil { - fallbackSource := source.sourceInfo() - if !result.Source.valid() { - result.Source = fallbackSource + if e.extractViewFunc != nil { + result, err := e.extractViewFunc(r) + if err != nil && !result.Source.valid() { + result.Source = sourceValueFromError(err) } - result.Source = e.getSource(result, err) return result, err } - return result, nil + for i := range e.sources { + source := &e.sources[i] + if i > 0 { + if err := r.context().Err(); err != nil { + return Extraction{}, err + } + } + + var ( + result Extraction + err error + ) + + switch source.source.kind { + case sourceForwarded: + result, err = e.extractChainSource( + r, + source, + "Forwarded chain exceeds configured maximum length", + "request received from untrusted proxy while Forwarded is present", + func(err error) { + if !errors.Is(err, ErrInvalidForwardedHeader) { + return + } + e.config.metrics.RecordSecurityEvent(SecurityEventMalformedForwarded) + e.logSecurityWarning(r, source.source, SecurityEventMalformedForwarded, "malformed Forwarded header received", "parse_error", err.Error()) + }, + ) + case sourceXForwardedFor: + result, err = e.extractChainSource( + r, + source, + "X-Forwarded-For chain exceeds configured maximum length", + "request received from untrusted proxy while X-Forwarded-For is present", + nil, + ) + case sourceRemoteAddr: + result, err = e.extractRemoteAddrSource(r, source) + default: + result, err = e.extractSingleHeaderSource(r, source) + } + if err == nil { + return result, nil + } + + if sourceIsTerminalError(err) { + if !result.Source.valid() { + result.Source = sourceValueFromError(err) + } + return result, err + } + + if i == len(e.sources)-1 { + if !result.Source.valid() { + result.Source = sourceValueFromError(err) + } + return result, err + } + } + + return Extraction{}, ErrSourceUnavailable } func (e *Extractor) extractFromRemoteAddr(remoteAddr string) (Extraction, error) { @@ -122,7 +191,7 @@ func (e *Extractor) extractFromRemoteAddr(remoteAddr string) (Extraction, error) e.config.metrics.RecordExtractionFailure(source.String()) } err := adaptRemoteAddrFailure(failure, source) - result.Source = e.getSource(result, err) + result.Source = sourceValueFromError(err) return result, err } @@ -130,19 +199,77 @@ func (e *Extractor) extractFromRemoteAddr(remoteAddr string) (Extraction, error) return result, nil } -// getSource resolves the authoritative source for a result. -// -// Precedence: error-embedded source > result source > extractor default. -func (e *Extractor) getSource(result Extraction, err error) Source { - if err != nil { - var sourceErr interface{ SourceValue() Source } - if errors.As(err, &sourceErr) { - return sourceErr.SourceValue() +func (e *Extractor) buildConfiguredSources(sources []Source) []configuredSource { + configured := make([]configuredSource, len(sources)) + for i, source := range sources { + source := source + headerName, _ := sourceHeaderKey(source) + if headerName != "" { + headerName = textproto.CanonicalMIMEHeaderKey(headerName) } - return Source{} + + configuredSource := configuredSource{ + source: source, + name: source.String(), + unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: source}, + } + + switch source.kind { + case sourceForwarded: + configuredSource.chain = chainExtractor{policy: chainPolicy{ + headerName: headerName, + parseValues: func(values []string) ([]string, error) { + parts, err := parseForwardedValues(values, e.config.maxChainLength) + if err != nil { + return nil, adaptForwardedParseError(err, source, e) + } + return parts, nil + }, + parseClientIP: parseChainIP, + clientIP: e.clientIP, + trustedProxy: e.proxy, + selection: e.config.chainSelection, + collectDebugInfo: e.config.debugMode, + untrustedChainSep: ", ", + }} + case sourceXForwardedFor: + configuredSource.chain = chainExtractor{policy: chainPolicy{ + headerName: headerName, + parseValues: func(values []string) ([]string, error) { + parts, err := parseXFFValues(values, e.config.maxChainLength) + if err != nil { + return nil, adaptXFFParseError(err, source, e) + } + return parts, nil + }, + parseClientIP: parseIP, + clientIP: e.clientIP, + trustedProxy: e.proxy, + selection: e.config.chainSelection, + collectDebugInfo: e.config.debugMode, + untrustedChainSep: ", ", + }} + case sourceRemoteAddr: + configuredSource.remote = remoteAddrExtractor{clientIPPolicy: e.clientIP} + default: + configuredSource.single = singleHeaderExtractor{policy: singleHeaderPolicy{ + headerName: headerName, + clientIP: e.clientIP, + trustedProxy: e.proxy, + }} + } + + configured[i] = configuredSource } - if result.Source.valid() { - return result.Source + + return configured +} + +func sourceValueFromError(err error) Source { + var sourceErr interface{ SourceValue() Source } + if errors.As(err, &sourceErr) { + return sourceErr.SourceValue() } - return e.source.sourceInfo() + + return Source{} } diff --git a/extractor_orchestration_test.go b/extractor_orchestration_test.go new file mode 100644 index 0000000..9c39b03 --- /dev/null +++ b/extractor_orchestration_test.go @@ -0,0 +1,55 @@ +package clientip + +import ( + "context" + "errors" + "net/http" + "net/textproto" + "testing" +) + +func TestExtract_AllSourcesUnavailableReturnsLastSource(t *testing.T) { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{SourceXRealIP, HeaderSource("CF-Connecting-IP")} + extractor := mustNewExtractor(t, cfg) + + req := &http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)} + + result, err := extractor.Extract(req) + if !errors.Is(err, ErrSourceUnavailable) { + t.Fatalf("error = %v, want ErrSourceUnavailable", err) + } + if got, want := result.Source, HeaderSource("CF-Connecting-IP"); got != want { + t.Fatalf("source = %q, want %q", got, want) + } +} + +func TestExtractInput_ContextCanceledBeforeFallbackSource(t *testing.T) { + cfg := DefaultConfig() + cfg.TrustedProxyPrefixes = LoopbackProxyPrefixes() + cfg.Sources = []Source{HeaderSource("CF-Connecting-IP"), SourceRemoteAddr} + extractor := mustNewExtractor(t, cfg) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + requestedHeaders := make([]string, 0, 1) + cfHeader := textproto.CanonicalMIMEHeaderKey("CF-Connecting-IP") + _, err := extractor.ExtractInput(Input{ + Context: ctx, + RemoteAddr: "127.0.0.1:8080", + Headers: HeaderValuesFunc(func(name string) []string { + requestedHeaders = append(requestedHeaders, name) + cancel() + return nil + }), + }) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("error = %v, want context.Canceled", err) + } + if len(requestedHeaders) != 1 || requestedHeaders[0] != cfHeader { + t.Fatalf("requested headers = %v, want [%q]", requestedHeaders, cfHeader) + } +} diff --git a/resolver_test.go b/resolver_test.go index 77389d8..be992ad 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -14,15 +14,14 @@ import ( "github.com/google/go-cmp/cmp" ) -type countingResolverSource struct { - calls int - result Extraction - err error - extractFn func(requestView) (Extraction, error) - sourceName Source +type countingResolverExtractor struct { + calls int + result Extraction + err error + extractFn func(requestView) (Extraction, error) } -func (s *countingResolverSource) extract(req requestView) (Extraction, error) { +func (s *countingResolverExtractor) extract(req requestView) (Extraction, error) { s.calls++ if s.extractFn != nil { return s.extractFn(req) @@ -31,28 +30,13 @@ func (s *countingResolverSource) extract(req requestView) (Extraction, error) { return s.result, s.err } -func (s *countingResolverSource) name() string { - return "counting" -} - -func (s *countingResolverSource) sourceInfo() Source { - if s.sourceName.valid() { - return s.sourceName - } - if s.result.Source.valid() { - return s.result.Source - } - - return SourceRemoteAddr -} - -func newResolverTestExtractor(source sourceExtractor) *Extractor { +func newResolverTestExtractor(source *countingResolverExtractor) *Extractor { return &Extractor{ config: &config{ sourcePriority: []Source{HeaderSource("X-Test-IP")}, sourceHeaderKeys: []string{"X-Test-IP"}, }, - source: source, + extractViewFunc: source.extract, } } @@ -68,7 +52,7 @@ func mustNewResolver(t *testing.T, extractor *Extractor, config ResolverConfig) } func TestNewResolver_InvalidConfig(t *testing.T) { - extractor := newResolverTestExtractor(&countingResolverSource{}) + extractor := newResolverTestExtractor(&countingResolverExtractor{}) tests := []struct { name string @@ -116,7 +100,7 @@ func TestNewResolver_InvalidConfig(t *testing.T) { } func TestResolver_ResolveStrict_CachesSuccess(t *testing.T) { - source := &countingResolverSource{ + source := &countingResolverExtractor{ result: Extraction{IP: netip.MustParseAddr("8.8.8.8"), Source: SourceXRealIP}, } resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{}) @@ -158,7 +142,7 @@ func TestResolver_ResolveStrict_CachesSuccess(t *testing.T) { func TestResolver_ResolveStrict_CachesFailure(t *testing.T) { strictErr := &ExtractionError{Err: ErrInvalidIP, Source: SourceXRealIP} - source := &countingResolverSource{err: strictErr} + source := &countingResolverExtractor{err: strictErr} resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{}) req := &http.Request{RemoteAddr: "203.0.113.10:443", Header: make(http.Header)} @@ -191,7 +175,7 @@ func TestResolver_ResolveStrict_CachesFailure(t *testing.T) { } func TestResolver_ResolvePreferred_ReusesStrictCachedResult(t *testing.T) { - source := &countingResolverSource{ + source := &countingResolverExtractor{ result: Extraction{IP: netip.MustParseAddr("8.8.8.8"), Source: SourceXRealIP}, } resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{PreferredFallback: PreferredFallbackRemoteAddr}) @@ -221,7 +205,7 @@ func TestResolver_ResolvePreferred_ReusesStrictCachedResult(t *testing.T) { func TestResolver_ResolvePreferred_ParseRemoteAddrFallback(t *testing.T) { strictErr := &ExtractionError{Err: ErrInvalidIP, Source: SourceXRealIP} - source := &countingResolverSource{err: strictErr} + source := &countingResolverExtractor{err: strictErr} resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{PreferredFallback: PreferredFallbackRemoteAddr}) req := &http.Request{RemoteAddr: "127.0.0.1:8080", Header: make(http.Header)} @@ -262,7 +246,7 @@ func TestResolver_ResolvePreferred_ParseRemoteAddrFallback(t *testing.T) { func TestResolver_ResolvePreferred_StaticFallback(t *testing.T) { strictErr := &ExtractionError{Err: ErrInvalidIP, Source: SourceXRealIP} - source := &countingResolverSource{err: strictErr} + source := &countingResolverExtractor{err: strictErr} resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{ PreferredFallback: PreferredFallbackStaticIP, StaticFallbackIP: netip.MustParseAddr("0.0.0.0"), @@ -297,7 +281,7 @@ func TestResolver_ResolvePreferred_StaticFallback(t *testing.T) { } func TestResolver_ResolvePreferred_DoesNotFallbackOnCanceledOrDeadline(t *testing.T) { - resolver := mustNewResolver(t, newResolverTestExtractor(&countingResolverSource{ + resolver := mustNewResolver(t, newResolverTestExtractor(&countingResolverExtractor{ extractFn: func(req requestView) (Extraction, error) { return Extraction{}, req.context().Err() }, @@ -360,7 +344,7 @@ func TestResolver_ResolvePreferred_DoesNotFallbackOnCanceledOrDeadline(t *testin } func TestResolver_ResolveInputStrict_CachesSuccess(t *testing.T) { - source := &countingResolverSource{ + source := &countingResolverExtractor{ result: Extraction{IP: netip.MustParseAddr("2001:db8::1"), Source: SourceXRealIP}, } resolver := mustNewResolver(t, newResolverTestExtractor(source), ResolverConfig{}) diff --git a/source_build_test.go b/source_build_test.go deleted file mode 100644 index 81e7147..0000000 --- a/source_build_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package clientip - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestCompileSpecFromSource(t *testing.T) { - tests := []struct { - name string - source Source - want sourceSpec - }{ - { - name: "forwarded source", - source: SourceForwarded, - want: sourceSpec{ - kind: sourceExtractorKindForwarded, - source: SourceForwarded, - headerName: "Forwarded", - }, - }, - { - name: "x forwarded for source", - source: SourceXForwardedFor, - want: sourceSpec{ - kind: sourceExtractorKindXForwardedFor, - source: SourceXForwardedFor, - headerName: "X-Forwarded-For", - }, - }, - { - name: "x real ip source", - source: SourceXRealIP, - want: sourceSpec{ - kind: sourceExtractorKindSingleHeader, - source: SourceXRealIP, - headerName: "X-Real-Ip", - }, - }, - { - name: "remote addr source", - source: SourceRemoteAddr, - want: sourceSpec{ - kind: sourceExtractorKindRemoteAddr, - source: SourceRemoteAddr, - headerName: "", - }, - }, - { - name: "custom header source", - source: HeaderSource("x-custom-header"), - want: sourceSpec{ - kind: sourceExtractorKindSingleHeader, - source: HeaderSource("X-Custom-Header"), - headerName: "X-Custom-Header", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if diff := cmp.Diff(tt.want, compileSpecFromSource(tt.source), cmp.AllowUnexported(sourceSpec{}, Source{})); diff != "" { - t.Fatalf("compileSpecFromSource() mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/source_chained.go b/source_chained.go deleted file mode 100644 index fb39b49..0000000 --- a/source_chained.go +++ /dev/null @@ -1,56 +0,0 @@ -package clientip - -import "strings" - -type chainedSource struct { - sources []sourceExtractor - sourceName string - isTerminal func(error) bool -} - -func newChainedSource(isTerminal func(error) bool, sources ...sourceExtractor) *chainedSource { - names := make([]string, len(sources)) - for i, s := range sources { - names[i] = s.name() - } - - return &chainedSource{ - sources: sources, - sourceName: "chained[" + strings.Join(names, ",") + "]", - isTerminal: isTerminal, - } -} - -func (c *chainedSource) extract(r requestView) (Extraction, error) { - var lastErr error - for i, source := range c.sources { - // Context is already checked by extractWithSource before the first - // source; only re-check between subsequent sources in the chain. - if i > 0 { - if err := r.context().Err(); err != nil { - return Extraction{}, err - } - } - - result, err := source.extract(r) - if err == nil { - return result, nil - } - - if c.isTerminal != nil && c.isTerminal(err) { - return Extraction{}, err - } - - lastErr = err - } - - return Extraction{}, lastErr -} - -func (c *chainedSource) name() string { - return c.sourceName -} - -func (c *chainedSource) sourceInfo() Source { - return Source{} -} diff --git a/source_chained_test.go b/source_chained_test.go deleted file mode 100644 index f19dc55..0000000 --- a/source_chained_test.go +++ /dev/null @@ -1,246 +0,0 @@ -package clientip - -import ( - "context" - "errors" - "net/netip" - "testing" -) - -// mockSourceExtractor is a test double for sourceExtractor. -type mockSourceExtractor struct { - extractFn func(r requestView) (Extraction, error) - nameValue string - sourceValue Source -} - -func (m *mockSourceExtractor) extract(r requestView) (Extraction, error) { - return m.extractFn(r) -} - -func (m *mockSourceExtractor) name() string { - return m.nameValue -} - -func (m *mockSourceExtractor) sourceInfo() Source { - return m.sourceValue -} - -func TestChainedSource_ReturnsFirstSuccess(t *testing.T) { - wantIP := netip.MustParseAddr("1.2.3.4") - wantSource := SourceXForwardedFor - - first := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{IP: wantIP, Source: wantSource}, nil - }, - nameValue: "first", - } - second := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - t.Fatal("second source should not be called") - return Extraction{}, nil - }, - nameValue: "second", - } - - chain := newChainedSource(nil, first, second) - result, err := chain.extract(requestView{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.IP != wantIP { - t.Errorf("IP = %v, want %v", result.IP, wantIP) - } - if result.Source != wantSource { - t.Errorf("Source = %v, want %v", result.Source, wantSource) - } -} - -func TestChainedSource_SkipsNonTerminalErrors(t *testing.T) { - wantIP := netip.MustParseAddr("5.6.7.8") - nonTerminal := errors.New("not terminal") - - first := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{}, nonTerminal - }, - nameValue: "first", - } - second := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{IP: wantIP, Source: SourceRemoteAddr}, nil - }, - nameValue: "second", - } - - chain := newChainedSource(sourceIsTerminalError, first, second) - result, err := chain.extract(requestView{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.IP != wantIP { - t.Errorf("IP = %v, want %v", result.IP, wantIP) - } -} - -func TestChainedSource_SkipsErrSourceUnavailable(t *testing.T) { - wantIP := netip.MustParseAddr("10.0.0.1") - - unavailable := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{}, &ExtractionError{Err: ErrSourceUnavailable, Source: SourceXForwardedFor} - }, - nameValue: "unavailable", - } - fallback := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{IP: wantIP, Source: SourceRemoteAddr}, nil - }, - nameValue: "fallback", - } - - chain := newChainedSource(sourceIsTerminalError, unavailable, fallback) - result, err := chain.extract(requestView{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.IP != wantIP { - t.Errorf("IP = %v, want %v", result.IP, wantIP) - } -} - -func TestChainedSource_ContextCanceledIsTerminal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - first := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{}, context.Canceled - }, - nameValue: "first", - } - second := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - t.Fatal("second source should not be called after terminal error") - return Extraction{}, nil - }, - nameValue: "second", - } - - chain := newChainedSource(sourceIsTerminalError, first, second) - _, err := chain.extract(requestView{ctx: ctx}) - if err == nil { - t.Fatal("expected error, got nil") - } - if !errors.Is(err, context.Canceled) { - t.Errorf("error = %v, want context.Canceled", err) - } -} - -func TestChainedSource_TerminalErrorStopsChain(t *testing.T) { - terminalErr := &ExtractionError{Err: ErrUntrustedProxy, Source: SourceXForwardedFor} - - first := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{}, terminalErr - }, - nameValue: "first", - } - second := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - t.Fatal("second source should not be called after terminal error") - return Extraction{}, nil - }, - nameValue: "second", - } - - chain := newChainedSource(sourceIsTerminalError, first, second) - _, err := chain.extract(requestView{}) - if err == nil { - t.Fatal("expected error, got nil") - } - if !errors.Is(err, ErrUntrustedProxy) { - t.Errorf("error = %v, want ErrUntrustedProxy", err) - } -} - -func TestChainedSource_AllFailReturnsLastError(t *testing.T) { - err1 := &ExtractionError{Err: ErrSourceUnavailable, Source: SourceXForwardedFor} - err2 := &ExtractionError{Err: ErrSourceUnavailable, Source: SourceRemoteAddr} - - first := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{}, err1 - }, - nameValue: "first", - } - second := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - return Extraction{}, err2 - }, - nameValue: "second", - } - - chain := newChainedSource(sourceIsTerminalError, first, second) - _, err := chain.extract(requestView{}) - if !errors.Is(err, err2) { - t.Errorf("error = %v, want %v (last error)", err, err2) - } -} - -func TestChainedSource_Name(t *testing.T) { - a := &mockSourceExtractor{nameValue: "alpha"} - b := &mockSourceExtractor{nameValue: "beta"} - chain := newChainedSource(nil, a, b) - - want := "chained[alpha,beta]" - if got := chain.name(); got != want { - t.Errorf("name() = %q, want %q", got, want) - } -} - -func TestChainedSource_SourceInfo(t *testing.T) { - chain := newChainedSource(nil, &mockSourceExtractor{nameValue: "a"}) - got := chain.sourceInfo() - if got.valid() { - t.Errorf("sourceInfo() should return invalid Source, got %v", got) - } -} - -func TestChainedSource_ContextCanceledBeforeSecondSource(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - firstCalled := false - secondCalled := false - first := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - firstCalled = true - cancel() // cancel context after first source runs - return Extraction{}, &ExtractionError{Err: ErrSourceUnavailable, Source: SourceXForwardedFor} - }, - nameValue: "first", - } - second := &mockSourceExtractor{ - extractFn: func(r requestView) (Extraction, error) { - secondCalled = true - return Extraction{}, nil - }, - nameValue: "second", - } - - chain := newChainedSource(sourceIsTerminalError, first, second) - _, err := chain.extract(requestView{ctx: ctx}) - if err == nil { - t.Fatal("expected error for cancelled context") - } - if !errors.Is(err, context.Canceled) { - t.Errorf("error = %v, want context.Canceled", err) - } - if !firstCalled { - t.Error("first source should have been called") - } - if secondCalled { - t.Error("second source should not be called after context cancellation") - } -} diff --git a/source_compile.go b/source_compile.go deleted file mode 100644 index 1c01863..0000000 --- a/source_compile.go +++ /dev/null @@ -1,45 +0,0 @@ -package clientip - -type sourceExtractor interface { - extract(r requestView) (Extraction, error) - name() string - sourceInfo() Source -} - -type sourceExtractorKind uint8 - -const ( - sourceExtractorKindForwarded sourceExtractorKind = iota + 1 - sourceExtractorKindXForwardedFor - sourceExtractorKindSingleHeader - sourceExtractorKindRemoteAddr -) - -type sourceSpec struct { - kind sourceExtractorKind - source Source - headerName string -} - -type sourceExecuteFunc func(requestView, sourceSpec) (Extraction, error) - -type compiledSource struct { - spec sourceSpec - execute sourceExecuteFunc -} - -func compileSource(spec sourceSpec, execute sourceExecuteFunc) sourceExtractor { - return &compiledSource{spec: spec, execute: execute} -} - -func (s *compiledSource) extract(r requestView) (Extraction, error) { - return s.execute(r, s.spec) -} - -func (s *compiledSource) name() string { - return s.spec.source.String() -} - -func (s *compiledSource) sourceInfo() Source { - return s.spec.source -} diff --git a/source_execution.go b/source_execution.go index d1b5020..dd75d54 100644 --- a/source_execution.go +++ b/source_execution.go @@ -4,169 +4,57 @@ import ( "context" "errors" "fmt" - "net/textproto" ) -func (e *Extractor) buildSourceChain(cfg *config) sourceExtractor { - sources := make([]sourceExtractor, 0, len(cfg.sourcePriority)) - for _, configuredSource := range cfg.sourcePriority { - spec := compileSpecFromSource(configuredSource) - executor := e.compileExecutor(spec, configuredSource) - sources = append(sources, compileSource(spec, executor)) +func (e *Extractor) extractChainSource( + r requestView, + source *configuredSource, + chainTooLongMessage string, + untrustedProxyMessage string, + handleParseError func(error), +) (Extraction, error) { + result, failure, err := source.chain.extract(r, source.source) + if err != nil { + e.handleChainError(r, source.source, err, chainTooLongMessage, handleParseError) + return Extraction{}, err } - - if len(sources) == 1 { - return sources[0] + if failure != nil { + if failure.kind == failureSourceUnavailable { + return Extraction{}, source.unavailableErr + } + return Extraction{}, e.adaptChainFailure(r, source.source, failure, untrustedProxyMessage) } - return newChainedSource(sourceIsTerminalError, sources...) + e.config.metrics.RecordExtractionSuccess(source.name) + return result, nil } -func compileSpecFromSource(source Source) sourceSpec { - source = canonicalSource(source) - spec := sourceSpec{source: source} - - switch source.kind { - case sourceForwarded: - spec.kind = sourceExtractorKindForwarded - spec.headerName = "Forwarded" - case sourceXForwardedFor: - spec.kind = sourceExtractorKindXForwardedFor - spec.headerName = "X-Forwarded-For" - case sourceRemoteAddr: - spec.kind = sourceExtractorKindRemoteAddr - default: - spec.kind = sourceExtractorKindSingleHeader - headerName, _ := source.headerKey() - spec.headerName = textproto.CanonicalMIMEHeaderKey(headerName) +func (e *Extractor) extractSingleHeaderSource(r requestView, source *configuredSource) (Extraction, error) { + result, failure := source.single.extract(r, source.source) + if failure != nil { + if failure.kind == failureSourceUnavailable { + return Extraction{}, source.unavailableErr + } + return Extraction{}, e.adaptSingleHeaderFailure(r, source.source, failure) } - return spec + e.config.metrics.RecordExtractionSuccess(source.name) + return result, nil } -func (e *Extractor) compileExecutor(spec sourceSpec, configuredSource Source) sourceExecuteFunc { - source := canonicalSource(configuredSource) - // Pre-compute source name string once to avoid per-call allocations - // from normalizeSourceName (strings.ToLower + ReplaceAll). - sourceName := source.String() - // Pre-allocate the source-unavailable error once per source to avoid - // allocating on every fallback miss in multi-source chains. - sourceUnavailableErr := &ExtractionError{Err: ErrSourceUnavailable, Source: source} - - switch spec.kind { - case sourceExtractorKindForwarded: - ce := chainExtractor{policy: chainPolicy{ - headerName: "Forwarded", - parseValues: func(values []string) ([]string, error) { - parts, err := parseForwardedValues(values, e.config.maxChainLength) - if err != nil { - return nil, adaptForwardedParseError(err, source, e) - } - return parts, nil - }, - parseClientIP: parseChainIP, - clientIP: e.clientIP, - trustedProxy: e.proxy, - selection: e.config.chainSelection, - collectDebugInfo: e.config.debugMode, - untrustedChainSep: ", ", - }} - return func(r requestView, _ sourceSpec) (Extraction, error) { - result, failure, err := ce.extract(r, source) - if err != nil { - e.handleChainError(r, source, err, - "Forwarded chain exceeds configured maximum length", - func(err error) { - if !errors.Is(err, ErrInvalidForwardedHeader) { - return - } - e.config.metrics.RecordSecurityEvent(SecurityEventMalformedForwarded) - e.logSecurityWarning(r, source, SecurityEventMalformedForwarded, "malformed Forwarded header received", "parse_error", err.Error()) - }, - ) - return Extraction{}, err - } - if failure != nil { - if failure.kind == failureSourceUnavailable { - return Extraction{}, sourceUnavailableErr - } - return Extraction{}, e.adaptChainFailure(r, source, failure, "request received from untrusted proxy while Forwarded is present") - } - e.config.metrics.RecordExtractionSuccess(sourceName) - return result, nil - } - - case sourceExtractorKindXForwardedFor: - ce := chainExtractor{policy: chainPolicy{ - headerName: "X-Forwarded-For", - parseValues: func(values []string) ([]string, error) { - parts, err := parseXFFValues(values, e.config.maxChainLength) - if err != nil { - return nil, adaptXFFParseError(err, source, e) - } - return parts, nil - }, - parseClientIP: parseIP, - clientIP: e.clientIP, - trustedProxy: e.proxy, - selection: e.config.chainSelection, - collectDebugInfo: e.config.debugMode, - untrustedChainSep: ", ", - }} - return func(r requestView, _ sourceSpec) (Extraction, error) { - result, failure, err := ce.extract(r, source) - if err != nil { - e.handleChainError(r, source, err, - "X-Forwarded-For chain exceeds configured maximum length", - nil, - ) - return Extraction{}, err - } - if failure != nil { - if failure.kind == failureSourceUnavailable { - return Extraction{}, sourceUnavailableErr - } - return Extraction{}, e.adaptChainFailure(r, source, failure, "request received from untrusted proxy while X-Forwarded-For is present") - } - e.config.metrics.RecordExtractionSuccess(sourceName) - return result, nil - } - - case sourceExtractorKindRemoteAddr: - re := remoteAddrExtractor{clientIPPolicy: e.clientIP} - return func(r requestView, _ sourceSpec) (Extraction, error) { - result, failure := re.extract(r.remoteAddr(), source) - if failure != nil { - if failure.kind == failureSourceUnavailable { - return Extraction{}, sourceUnavailableErr - } - e.recordInvalidClientIPDisposition(failure.clientIPDisposition) - e.config.metrics.RecordExtractionFailure(sourceName) - return Extraction{}, adaptRemoteAddrFailure(failure, source) - } - e.config.metrics.RecordExtractionSuccess(sourceName) - return result, nil - } - - default: - headerName := spec.headerName - she := singleHeaderExtractor{policy: singleHeaderPolicy{ - headerName: headerName, - clientIP: e.clientIP, - trustedProxy: e.proxy, - }} - return func(r requestView, _ sourceSpec) (Extraction, error) { - result, failure := she.extract(r, source) - if failure != nil { - if failure.kind == failureSourceUnavailable { - return Extraction{}, sourceUnavailableErr - } - return Extraction{}, e.adaptSingleHeaderFailure(r, source, failure) - } - e.config.metrics.RecordExtractionSuccess(sourceName) - return result, nil +func (e *Extractor) extractRemoteAddrSource(r requestView, source *configuredSource) (Extraction, error) { + result, failure := source.remote.extract(r.remoteAddr(), source.source) + if failure != nil { + if failure.kind == failureSourceUnavailable { + return Extraction{}, source.unavailableErr } + e.recordInvalidClientIPDisposition(failure.clientIPDisposition) + e.config.metrics.RecordExtractionFailure(source.name) + return Extraction{}, adaptRemoteAddrFailure(failure, source.source) } + + e.config.metrics.RecordExtractionSuccess(source.name) + return result, nil } func sourceIsTerminalError(err error) bool { From d3afd128ba9800b0dd9670c8f89429692abf0e39 Mon Sep 17 00:00:00 2001 From: Thomas de Jong Date: Thu, 23 Apr 2026 21:04:47 +0200 Subject: [PATCH 5/5] chore: update docs --- CHANGELOG.md | 3 +- README.md | 93 ++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 89 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e7e716..177435f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on Keep a Changelog and this project follows Semantic Versio - `Resolver`, `ResolverConfig`, `PreferredFallback`, and `Resolution` as the request-scoped API for strict and preferred client IP resolution. - `StrictResolutionFromContext` and `PreferredResolutionFromContext` for reusing cached resolver state across middleware. +- `Resolver.ResolveInputStrict` and `Resolver.ResolveInputPreferred` for framework-agnostic request-scoped resolution. - `Input`, `ExtractInput`, and `ExtractInputAddr` for framework-agnostic request handling. - `ParseRemoteAddr` helper. - `ClassifyError`, `ResultKind`, and result classification constants for coarse-grained policy handling. @@ -26,7 +27,7 @@ The format is based on Keep a Changelog and this project follows Semantic Versio - **BREAKING:** Preferred fallback is explicit resolver behavior with `Resolution.FallbackUsed`; fallback does not emit separate metrics or log events in this phase. - **BREAKING:** `SourceStaticFallback` remains public but is resolver-result-only; it cannot be used in `Config.Sources`. - Presets remain `Config` helpers and now document resolver-oriented usage more clearly. -- Prometheus integration is constructor-based: build metrics with `prometheus.New()` or `prometheus.NewWithRegisterer(...)` and assign them through `Config.Metrics`. +- Prometheus integration on `main` is constructor-based: build metrics with `prometheus.New()` or `prometheus.NewWithRegisterer(...)` and assign them through `Config.Metrics`. The published adapter module remains pinned to root `v0.0.6` until the matching adapter release is tagged. - `X-Forwarded-For` chain extraction again accepts the host:port and quoted forms already supported by `parseIP`, while `Forwarded` stays strict and now rejects present-but-empty values plus empty delimiter-created elements/parameters as malformed. - Internal orchestration now sits behind `internal/engine` and concrete source execution behind `internal/source`. diff --git a/README.md b/README.md index dcea79c..c734bf8 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,28 @@ Secure client IP extraction for `net/http` and framework-agnostic request inputs This project is pre-`v1.0.0` and still before `v0.1.0`, so public APIs may change as the package evolves. Any breaking changes are called out in `CHANGELOG.md`. +This README tracks the current `main` branch rather than the latest tagged release. + +## Contents + +- [Install](#install) +- [Choose the API](#choose-the-api) +- [Common Setups](#common-setups) +- [Rules To Remember](#rules-to-remember) +- [Quick Start](#quick-start) +- [Preferred Resolution And Fallback](#preferred-resolution-and-fallback) +- [Framework-Agnostic Input](#framework-agnostic-input) +- [Presets](#presets) +- [Config](#config) +- [Low-Level Extraction](#low-level-extraction) +- [Errors](#errors) +- [Logging](#logging) +- [Prometheus Metrics](#prometheus-metrics) +- [Security Guidance](#security-guidance) +- [Compatibility](#compatibility) +- [Performance](#performance) +- [Maintainer Notes (Multi-Module)](#maintainer-notes-multi-module) +- [License](#license) ## Install @@ -23,15 +45,69 @@ Optional Prometheus adapter: go get github.com/abczzz13/clientip/prometheus ``` +> Version note: the published adapter module is still pinned to `github.com/abczzz13/clientip v0.0.6` until the next coordinated release. This README documents the current `main`-branch API, so the Prometheus snippets below describe the upcoming adapter wiring after that release lands. + ## Choose the API -- `Resolver` is the primary API. Use it when middleware, handlers, or framework adapters need to resolve the client IP once and reuse the result on the same request. -- `Extractor` is the low-level strict primitive. Use it when you only need one extraction call and do not need request-scoped caching or preferred fallback. -- `Input` is the framework-agnostic carrier for non-`net/http` integrations. -- `ParseRemoteAddr` and `ClassifyError` are small helpers for explicit fallback and policy code. +Use this as a quick decision guide: + +| Need | Use | +| --- | --- | +| Security-sensitive or audit-oriented result | `Resolver.ResolveStrict` or `Extractor` | +| Best-effort operational IP with explicit fallback | `Resolver.ResolvePreferred` | +| Framework integration without `*http.Request` | `Input` with `Resolver` or `Extractor` | +| Parse `RemoteAddr` outside extraction | `ParseRemoteAddr` | +| Coarse policy branching on error categories | `ClassifyError` | Construct an `Extractor` once and reuse it. Build a `Resolver` on top when you want strict or preferred request-scoped resolution. +## Common Setups + +Most integrations start with a preset and only drop to a fully manual `Config` when the proxy topology is unusual. + +Direct app-to-client traffic: + +```go +extractor, err := clientip.New(clientip.PresetDirectConnection()) +``` + +Reverse proxy on the same host: + +```go +extractor, err := clientip.New(clientip.PresetLoopbackReverseProxy()) +``` + +Reverse proxy on a VM or private network: + +```go +extractor, err := clientip.New(clientip.PresetVMReverseProxy()) +``` + +Custom trusted header source: + +```go +extractor, err := clientip.New(clientip.Config{ + TrustedProxyPrefixes: clientip.LoopbackProxyPrefixes(), + Sources: []clientip.Source{ + clientip.HeaderSource("CF-Connecting-IP"), + clientip.SourceRemoteAddr, + }, +}) +``` + +Framework request input: + +Use `Input` when the framework does not hand you `*http.Request` directly. The same extractor and resolver rules still apply. + +## Rules To Remember + +- `Resolver` is the primary integration-facing API; `Extractor` is the lower-level strict primitive. +- Header-based sources require `TrustedProxyPrefixes`. +- Prefer a preset first, then tweak `Config` only when needed. +- Only configure one proxy-chain source at a time: `SourceForwarded` or `SourceXForwardedFor`. +- Preferred fallback is operationally useful, but not suitable for authorization, ACLs, or trust-boundary enforcement. +- `Input.Headers` must preserve repeated header lines as separate slice entries; merging them breaks duplicate detection and chain parsing semantics. + ## Quick Start Use `Resolver.ResolveStrict` for security-sensitive or audit-oriented decisions. @@ -180,6 +256,8 @@ Presets configure `Config`, not `ResolverConfig`. Preferred resolver fallback st `Config` stays flat in the current API. +Most callers should start from `PresetDirectConnection`, `PresetLoopbackReverseProxy`, or `PresetVMReverseProxy`, then adjust the returned `Config` if they need custom trust ranges, source order, or observability wiring. + Important fields: - `TrustedProxyPrefixes []netip.Prefix` @@ -316,6 +394,8 @@ Security event labels passed through `Metrics.RecordSecurityEvent(...)` are the Construct Prometheus metrics explicitly and pass them through `Config.Metrics`. +This constructor-based wiring reflects the current `main` branch. Tagged-release consumers of `github.com/abczzz13/clientip/prometheus` still get the older adapter release that depends on root `v0.0.6` until the next coordinated adapter tag is published. + ```go import clientipprom "github.com/abczzz13/clientip/prometheus" @@ -359,7 +439,8 @@ if err != nil { ## Compatibility - Core module (`github.com/abczzz13/clientip`) supports Go `1.21+`. -- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) has a minimum Go version of `1.21`; CI currently validates consumer mode on Go `1.21.x` and `1.26.x`. +- Optional Prometheus adapter (`github.com/abczzz13/clientip/prometheus`) has a minimum Go version of `1.21`; CI currently validates consumer mode on Go `1.21.x` and `1.26.x` against the released root module until the next coordinated adapter release is tagged. +- The published adapter module is currently pinned to `github.com/abczzz13/clientip v0.0.6`; the constructor-based `Config.Metrics` wiring documented above becomes the tagged-release path once that coordinated release ships. - Prometheus client dependency in the adapter is pinned to `github.com/prometheus/client_golang v1.21.1`. ## Performance @@ -387,7 +468,7 @@ You can compare arbitrary files directly via `just bench-compare < - `prometheus/go.mod` intentionally does not use a local `replace` directive for `github.com/abczzz13/clientip`. - For local co-development, create an uncommitted workspace with `go work init . ./prometheus`. -- Validate the adapter as a consumer with `GOWORK=off go -C prometheus test ./...`. +- Validate the adapter as a consumer with `GOWORK=off go -C prometheus test ./...`; until the next coordinated release, this intentionally exercises the released root module `v0.0.6` instead of the unreleased workspace API. - `just` and CI validate the adapter in consumer mode by default (`GOWORK=off`); set `CLIENTIP_ADAPTER_GOWORK=auto` locally when you intentionally want workspace-mode adapter checks. - Release in this order: tag root module `vX.Y.Z`, bump `prometheus/go.mod` to that version, then tag adapter module `prometheus/vX.Y.Z`.