diff --git a/AGENTS.md b/AGENTS.md index 886b3b75..be8ca8be 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -47,6 +47,7 @@ Resilience and safety: - `commons/assert`: production-safe assertions with telemetry integration and domain predicates - `commons/safe`: panic-free math/regex/slice operations with error returns - `commons/security`: sensitive field detection and handling +- `commons/security/ssrf`: canonical SSRF validation — IP blocking (CIDR blocklist + stdlib predicates), hostname blocking (metadata endpoints, dangerous suffixes), URL validation, DNS-pinned resolution with TOCTOU elimination - `commons/errgroup`: goroutine coordination with panic recovery - `commons/certificate`: thread-safe TLS certificate manager with hot reload, PEM file loading, PKCS#8/PKCS#1/EC key support, and `tls.Config` integration @@ -137,6 +138,17 @@ Build and shell: - TLS integration: `TLSCertificate() tls.Certificate` (returns populated `tls.Certificate` struct); `GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error)` — suitable for assignment to `tls.Config.GetCertificate` for transparent hot-reload. - Package-level helper: `LoadFromFiles(certPath, keyPath string) (*x509.Certificate, crypto.Signer, error)` — validates without touching any manager state, useful for pre-flight checking before calling `Rotate`. +### SSRF validation (`commons/security/ssrf`) + +- `IsBlockedIP(net.IP)` and `IsBlockedAddr(netip.Addr)` for IP-level SSRF blocking. `IsBlockedAddr` is the core; `IsBlockedIP` delegates after conversion. +- `IsBlockedHostname(hostname)` for hostname-level blocking: localhost, cloud metadata endpoints, `.local`/`.internal`/`.cluster.local` suffixes. +- `BlockedPrefixes() []netip.Prefix` returns a copy of the canonical CIDR blocklist (8 ranges: this-network, CGNAT, IETF assignments, TEST-NET-1/2/3, benchmarking, reserved). +- `ValidateURL(ctx, rawURL, opts...)` validates scheme, hostname, and parsed IP without DNS. +- `ResolveAndValidate(ctx, rawURL, opts...) 
(*ResolveResult, error)` performs DNS-pinned validation. `ResolveResult` has `PinnedURL`, `Authority` (host:port for HTTP Host), `SNIHostname` (for TLS ServerName). +- Functional options: `WithHTTPSOnly()`, `WithAllowPrivateNetwork()`, `WithLookupFunc(fn)`, `WithAllowHostname(hostname)`. +- Sentinel errors: `ErrBlocked`, `ErrInvalidURL`, `ErrDNSFailed`. +- Both `commons/webhook` and `commons/net/http` delegate to this package — it is the single source of truth for SSRF blocking across all Lerian services. + ### Assertions (`commons/assert`) - `New(ctx, logger, component, operation) *Asserter` and return errors instead of panicking. @@ -238,6 +250,7 @@ Build and shell: - **Pointers:** `String()`, `Bool()`, `Time()`, `Int()`, `Int64()`, `Float64()`. - **Cron:** `Parse(expr) (Schedule, error)`; `Schedule.Next(t) (time.Time, error)`. - **Security:** `IsSensitiveField(name)`, `DefaultSensitiveFields()`, `DefaultSensitiveFieldsMap()`. +- **SSRF:** `IsBlockedIP()`, `IsBlockedAddr()`, `IsBlockedHostname()`, `BlockedPrefixes()`, `ValidateURL()`, `ResolveAndValidate()`. Single source of truth for all SSRF protection. Both `webhook` and `net/http` delegate here. - **Transaction:** `BuildIntentPlan()` + `ValidateBalanceEligibility()` + `ApplyPosting()` with typed `IntentPlan`, `Posting`, `LedgerTarget`. `ResolveOperation(pending, isSource, status) (Operation, error)`. - **Constants:** `SanitizeMetricLabel(value) string` for OTEL label safety. - **Certificate:** `NewManager(certPath, keyPath) (*Manager, error)`; `Rotate(cert, key)`, `TLSCertificate()`, `GetCertificateFunc()`; package-level `LoadFromFiles(certPath, keyPath)` for pre-flight validation. 
diff --git a/commons/certificate/certificate.go b/commons/certificate/certificate.go index 6906d57c..194780cc 100644 --- a/commons/certificate/certificate.go +++ b/commons/certificate/certificate.go @@ -151,9 +151,10 @@ func (m *Manager) Rotate(cert *x509.Certificate, key crypto.Signer, intermediate return nil } -// GetCertificate returns the current certificate, or nil if none is loaded. -// The returned pointer shares state with the manager; callers must treat it -// as read-only. Use [LoadFromFiles] + [Manager.Rotate] to replace it. +// GetCertificate returns a deep copy of the current certificate, or nil if +// none is loaded. The returned *x509.Certificate is an independent clone that +// callers may freely modify without affecting the manager's internal state. +// Use [LoadFromFiles] + [Manager.Rotate] to replace the managed certificate. func (m *Manager) GetCertificate() *x509.Certificate { if m == nil { return nil @@ -162,7 +163,7 @@ func (m *Manager) GetCertificate() *x509.Certificate { m.mu.RLock() defer m.mu.RUnlock() - return m.cert + return cloneCert(m.cert) } // GetSigner returns the current private key as a crypto.Signer, or nil if none is loaded. @@ -212,7 +213,10 @@ func (m *Manager) ExpiresAt() time.Time { } // DaysUntilExpiry returns the number of days until the certificate expires. -// Returns -1 if no certificate is loaded. +// It returns -1 when no certificate is loaded (nil receiver or no certificate +// configured via [NewManager]). Otherwise it returns the number of days until +// expiry, which may be negative for already-expired certificates (e.g. -3 +// means the certificate expired 3 days ago). func (m *Manager) DaysUntilExpiry() int { if m == nil { return -1 @@ -228,8 +232,9 @@ func (m *Manager) DaysUntilExpiry() int { // TLSCertificate returns a [tls.Certificate] built from the currently loaded // certificate chain and private key. Returns an empty [tls.Certificate] if no -// certificate is loaded. 
The Leaf field shares state with the manager; callers -// should treat it as read-only. +// certificate is loaded. Both the Certificate [][]byte chain and the Leaf are +// deep copies, so callers never receive references aliasing internal state. +// Safe to call on a nil receiver (returns an empty [tls.Certificate]). func (m *Manager) TLSCertificate() tls.Certificate { if m == nil { return tls.Certificate{} @@ -252,7 +257,7 @@ func (m *Manager) TLSCertificate() tls.Certificate { return tls.Certificate{ Certificate: chainCopy, PrivateKey: m.signer, - Leaf: m.cert, + Leaf: cloneCert(m.cert), } } @@ -388,3 +393,25 @@ func publicKeysMatch(certPublicKey, signerPublicKey any) bool { return bytes.Equal(certDER, signerDER) } + +// cloneCert returns a deep copy of cert by re-parsing its DER encoding. +// Returns nil when cert is nil. +func cloneCert(cert *x509.Certificate) *x509.Certificate { + if cert == nil { + return nil + } + + raw := make([]byte, len(cert.Raw)) + copy(raw, cert.Raw) + + // ParseCertificate is the canonical way to deep-copy an x509.Certificate. + // Errors here are unexpected (the DER was already parsed once), but we + // return nil rather than panicking to stay consistent with the package's + // nil-safe contract. 
+ clone, err := x509.ParseCertificate(raw) + if err != nil { + return nil + } + + return clone +} diff --git a/commons/certificate/certificate_test.go b/commons/certificate/certificate_test.go index dba18d66..8b5ac909 100644 --- a/commons/certificate/certificate_test.go +++ b/commons/certificate/certificate_test.go @@ -860,20 +860,11 @@ func TestNilManager_SubTests(t *testing.T) { assert.Equal(t, -1, m.DaysUntilExpiry()) }) - t.Run("TLSCertificate returns empty", func(t *testing.T) { + t.Run("TLSCertificate returns empty on nil receiver", func(t *testing.T) { t.Parallel() - // TLSCertificate checks m == nil after acquiring the lock which it - // can't do on a nil receiver — production code checks m == nil inside - // the method body after the RLock, but since m is nil the lock call - // itself will panic. The method guards with `if m == nil` but the - // guard is AFTER the RLock. Let's verify by calling it safely. - // Actually the method body is: mu.RLock then `if m == nil` — a nil - // pointer dereference would occur. We document that TLSCertificate - // is NOT nil-receiver-safe (unlike the other methods) by asserting - // the returned value from a non-nil empty manager instead. - empty, emptyErr := NewManager("", "") - require.NoError(t, emptyErr) - tlsCert := empty.TLSCertificate() + // TLSCertificate is nil-receiver-safe: the nil guard precedes the + // RLock call, consistent with every other Manager method. + tlsCert := m.TLSCertificate() assert.Equal(t, tls.Certificate{}, tlsCert) }) } diff --git a/commons/dlq/consumer.go b/commons/dlq/consumer.go index e3c21000..55ed95bf 100644 --- a/commons/dlq/consumer.go +++ b/commons/dlq/consumer.go @@ -106,11 +106,15 @@ func WithBatchSize(n int) ConsumerOption { } } -// WithSources sets the DLQ source queue names to consume. +// WithSources sets the DLQ source queue names to consume. The input slice is +// cloned so that subsequent mutations by the caller do not race with the +// consumer's read loop. 
func WithSources(sources ...string) ConsumerOption { return func(c *Consumer) { if len(sources) > 0 { - c.cfg.Sources = sources + cloned := make([]string, len(sources)) + copy(cloned, sources) + c.cfg.Sources = cloned } } } diff --git a/commons/dlq/handler.go b/commons/dlq/handler.go index 63a92232..0b9a4188 100644 --- a/commons/dlq/handler.go +++ b/commons/dlq/handler.go @@ -148,52 +148,18 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { return ErrNilHandler } - if msg == nil { - return errors.New("dlq: enqueue: nil message") - } - - if msg.Source == "" { - return errors.New("dlq: enqueue: source must not be empty") - } - - if err := validateKeySegment("source", msg.Source); err != nil { + if err := validateEnqueueMessage(msg); err != nil { return err } ctx, span := h.tracer.Start(ctx, "dlq.enqueue") defer span.End() - // Only stamp CreatedAt and MaxRetries on initial enqueue (zero-valued). - // Re-enqueue paths (consumer retry-failed, not-yet-ready, prune) pass - // messages that already carry the original values; overwriting them would - // permanently lose the original failure timestamp and retry budget. - initialEnqueue := msg.CreatedAt.IsZero() - if initialEnqueue { - msg.CreatedAt = time.Now().UTC() - } - - if msg.MaxRetries == 0 { - msg.MaxRetries = h.maxRetries - } - - ctxTenant := tmcore.GetTenantIDContext(ctx) + h.stampInitialEnqueue(msg) - effectiveTenant := msg.TenantID - if effectiveTenant == "" { - effectiveTenant = ctxTenant - msg.TenantID = effectiveTenant - } - - if effectiveTenant != "" && ctxTenant != "" && effectiveTenant != ctxTenant { - return fmt.Errorf("dlq: enqueue: tenant mismatch between message (%s) and context (%s)", effectiveTenant, ctxTenant) - } - - // Recalculate NextRetryAt only on initial enqueue. 
On re-enqueue the - // consumer has already incremented RetryCount and the caller is - // responsible for timing; we preserve their NextRetryAt or let the - // backoff be recalculated by the consumer path that sets RetryCount. - if initialEnqueue && msg.RetryCount < msg.MaxRetries { - msg.NextRetryAt = msg.CreatedAt.Add(backoffDuration(msg.RetryCount)) + effectiveTenant, err := h.resolveAndValidateTenant(ctx, msg) + if err != nil { + return err } data, err := json.Marshal(msg) @@ -223,6 +189,74 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { return nil } +// validateEnqueueMessage performs pre-flight validation on the message before +// any state mutation or tracing begins. +func validateEnqueueMessage(msg *FailedMessage) error { + if msg == nil { + return errors.New("dlq: enqueue: nil message") + } + + if msg.Source == "" { + return errors.New("dlq: enqueue: source must not be empty") + } + + return validateKeySegment("source", msg.Source) +} + +// stampInitialEnqueue sets CreatedAt, MaxRetries, and NextRetryAt on messages +// that are being enqueued for the first time (CreatedAt is zero). +// Re-enqueue paths (consumer retry-failed, not-yet-ready, prune) pass messages +// that already carry the original values; overwriting them would permanently +// lose the original failure timestamp and retry budget. +func (h *Handler) stampInitialEnqueue(msg *FailedMessage) { + initialEnqueue := msg.CreatedAt.IsZero() + if initialEnqueue { + msg.CreatedAt = time.Now().UTC() + } + + if msg.MaxRetries <= 0 { + msg.MaxRetries = h.maxRetries + } + + // Recalculate NextRetryAt only on initial enqueue. On re-enqueue the + // consumer has already incremented RetryCount and the caller is + // responsible for timing; we preserve their NextRetryAt or let the + // backoff be recalculated by the consumer path that sets RetryCount. 
+ if initialEnqueue && msg.RetryCount < msg.MaxRetries { + msg.NextRetryAt = msg.CreatedAt.Add(backoffDuration(msg.RetryCount)) + } +} + +// resolveAndValidateTenant determines the effective tenant ID for the message by +// reconciling the message's TenantID with the tenant from context. It validates +// that they match when both are present, and validates the tenant as a safe Redis +// key segment. Returns the effective tenant ID. +func (h *Handler) resolveAndValidateTenant(ctx context.Context, msg *FailedMessage) (string, error) { + ctxTenant := tmcore.GetTenantIDContext(ctx) + + effectiveTenant := msg.TenantID + if effectiveTenant == "" { + effectiveTenant = ctxTenant + msg.TenantID = effectiveTenant + } + + if effectiveTenant != "" && ctxTenant != "" && effectiveTenant != ctxTenant { + return "", fmt.Errorf("dlq: enqueue: tenant mismatch between message (%s) and context (%s)", effectiveTenant, ctxTenant) + } + + // Validate the effective tenant before using it to construct a Redis key. + // This prevents invalid tenant IDs from silently falling back to the global + // (non-tenant) key inside tenantScopedKeyForTenant, which would mix + // tenant-scoped messages into the global queue. + if effectiveTenant != "" { + if err := validateKeySegment("tenantID", effectiveTenant); err != nil { + return "", fmt.Errorf("dlq: enqueue: %w", err) + } + } + + return effectiveTenant, nil +} + // logEnqueueFallback logs message metadata when Redis is unreachable. The // payload is redacted to prevent PII leakage into log aggregators. func (h *Handler) logEnqueueFallback(ctx context.Context, key string, msg *FailedMessage, err error) { diff --git a/commons/net/http/idempotency/doc.go b/commons/net/http/idempotency/doc.go index dfd0c50c..a4bd92a6 100644 --- a/commons/net/http/idempotency/doc.go +++ b/commons/net/http/idempotency/doc.go @@ -6,37 +6,22 @@ // requests may execute more than once. 
Callers that require strict at-most-once // guarantees must pair this middleware with application-level safeguards. // -// The middleware uses the X-Idempotency request header combined with the tenant ID -// (from tenant-manager context) to form a composite Redis key. When a tenant ID is -// present, keys are scoped per-tenant to prevent cross-tenant collisions. When no -// tenant is in context (e.g., non-tenant-scoped routes), keys are still valid but -// are shared across the global namespace. Duplicate requests receive the original -// response with the X-Idempotency-Replayed header set to "true". +// # Key composition // -// # Quick start -// -// conn, err := redis.New(ctx, cfg) -// if err != nil { -// log.Fatal(err) -// } -// idem := idempotency.New(conn) -// app.Post("/orders", idem.Check(), createOrderHandler) -// -// # Behavior +// The middleware uses the X-Idempotency request header ([constants.IdempotencyKey]) +// combined with the tenant ID (from tenant-manager context via +// [tmcore.GetTenantIDContext]) to form a composite Redis key: // -// - GET/OPTIONS requests pass through (idempotency is irrelevant for reads). -// - Absent header: request proceeds normally (idempotency is opt-in). -// - Header exceeds MaxKeyLength (default 256): request rejected with 400. -// - Duplicate key: cached response replayed with X-Idempotency-Replayed: true. -// - Redis failure: request proceeds (fail-open for availability). -// - Handler success: response cached; handler failure: key deleted (client can retry). +// : // -// # Redis key namespace convention +// When a tenant ID is present, keys are scoped per-tenant to prevent +// cross-tenant collisions. When no tenant is in context (e.g., non-tenant-scoped +// routes), the key becomes : in the global namespace. // -// Keys follow the pattern: : -// with a companion response key at ::response. +// A companion response key at ::response stores +// the cached response body and headers for replay. 
// -// The default prefix is "idempotency:" and can be overridden via WithKeyPrefix. +// The default prefix is "idempotency:" and can be overridden via [WithKeyPrefix]. // This namespacing convention is consistent with other lib-commons packages that // use Redis (e.g., rate limiting uses "ratelimit::..."). Per-tenant // isolation is enforced by embedding the tenant ID into the key rather than @@ -44,7 +29,44 @@ // implementation topology-agnostic (standalone, sentinel, and cluster all behave // identically with this approach). // +// # Quick start +// +// conn, err := redis.New(ctx, cfg) +// if err != nil { +// return err +// } +// idem := idempotency.New(conn) +// app.Post("/orders", idem.Check(), createOrderHandler) +// +// # Behavior branches +// +// The [Middleware.Check] handler evaluates requests through the following +// branches in order: +// +// - GET, HEAD, and OPTIONS requests pass through unconditionally — idempotency +// is not enforced for safe/idempotent HTTP methods. +// - Absent X-Idempotency header: request proceeds normally (idempotency is +// opt-in per request). +// - Header exceeds [WithMaxKeyLength] (default 256): request is passed to the +// configured [WithRejectedHandler]. When no custom handler is set, a 400 JSON +// response with code "VALIDATION_ERROR" is returned. +// - Redis unavailable (GetClient, SetNX, or Get failures): request proceeds +// without idempotency enforcement (fail-open), logged at WARN level. +// - Duplicate key with cached response: the original response is replayed +// faithfully — status code, headers (including Location, ETag, Set-Cookie), +// content type, and body — with [constants.IdempotencyReplayed] set to "true". +// - Duplicate key still in "processing" state (in-flight): 409 Conflict with +// code "IDEMPOTENCY_CONFLICT" is returned. 
+// - Duplicate key in "complete" state but no cached body (e.g., body exceeded +// [WithMaxBodyCache]): 200 OK with code "IDEMPOTENT" and detail "request +// already processed" is returned. +// - Handler success: response (status, headers, body) is cached via a Redis +// pipeline and the key is marked "complete". +// - Handler failure: both the lock key and response key are deleted so the +// client can retry with the same idempotency key. +// // # Nil safety // -// A nil *Middleware returns a pass-through handler from Check(). +// [New] returns nil when conn is nil. A nil [*Middleware] returns a pass-through +// handler from [Middleware.Check]. package idempotency diff --git a/commons/net/http/idempotency/idempotency.go b/commons/net/http/idempotency/idempotency.go index b4714e2d..06265154 100644 --- a/commons/net/http/idempotency/idempotency.go +++ b/commons/net/http/idempotency/idempotency.go @@ -2,7 +2,10 @@ package idempotency import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" + "errors" "fmt" "net/http" "time" @@ -21,10 +24,15 @@ const ( ) // cachedResponse stores the full HTTP response for idempotent replay. +// Body is stored as raw bytes (base64-encoded in JSON) so that binary and +// non-UTF-8 payloads survive a marshal/unmarshal round-trip. Headers preserves +// response headers that must be faithfully replayed (e.g., Location, ETag, +// Set-Cookie). type cachedResponse struct { - StatusCode int `json:"status_code"` - ContentType string `json:"content_type"` - Body string `json:"body"` + StatusCode int `json:"status_code"` + ContentType string `json:"content_type"` + Body []byte `json:"body"` + Headers map[string][]string `json:"headers,omitempty"` } // Option configures the idempotency middleware. @@ -146,6 +154,15 @@ func (m *Middleware) Check() fiber.Handler { return m.handle } +// redactKey returns a truncated SHA-256 hash of a Redis key for safe logging. 
+// Idempotency keys are client-controlled and tenant-scoped, so logging them +// verbatim would emit high-cardinality identifiers and potentially leak tenant +// or client information during incidents. +func redactKey(key string) string { + h := sha256.Sum256([]byte(key)) + return hex.EncodeToString(h[:8]) +} + func (m *Middleware) handle(c *fiber.Ctx) error { // Idempotency only applies to mutating methods. if c.Method() == fiber.MethodGet || c.Method() == fiber.MethodOptions || c.Method() == fiber.MethodHead { @@ -219,17 +236,63 @@ func (m *Middleware) handleDuplicate( key, responseKey string, ) error { // Read the current key value to distinguish in-flight from completed. - keyValue, _ := client.Get(ctx, key).Result() + keyValue, keyErr := client.Get(ctx, key).Result() + if keyErr != nil && !errors.Is(keyErr, redis.Nil) { + // Unexpected Redis error (timeout, connection failure) — fail open. + m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to read key state, failing open", + log.String("key_hash", redactKey(key)), log.Err(keyErr), + ) + + return c.Next() + } + + // The marker has vanished between the SetNX (which saw it) and this Get. + // This happens when the original request failed and deleted the key, or + // the TTL expired in the narrow window. Fail open so the duplicate can + // be retried rather than returning a false "already processed" response. + if errors.Is(keyErr, redis.Nil) { + return c.Next() + } // Try to replay the cached response (true idempotency). cached, cacheErr := client.Get(ctx, responseKey).Result() - if cacheErr == nil && cached != "" { + + switch { + case cacheErr != nil && !errors.Is(cacheErr, redis.Nil): + // Unexpected Redis error reading cached response — fail open. 
+ m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to read cached response, failing open", + log.String("key_hash", redactKey(responseKey)), log.Err(cacheErr), + ) + + return c.Next() + case cacheErr == nil && cached != "": var resp cachedResponse - if unmarshalErr := json.Unmarshal([]byte(cached), &resp); unmarshalErr == nil { + if unmarshalErr := json.Unmarshal([]byte(cached), &resp); unmarshalErr != nil { + // Cache entry is corrupt or written by an incompatible version. + // Log a warning so operators can investigate, then fall through + // to the generic "already processed" response (fail-open). + m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to unmarshal cached response, falling through to generic reply", + log.String("key_hash", redactKey(responseKey)), log.Err(unmarshalErr), + ) + } else { + // Replay persisted headers first so the caller sees + // Location, ETag, Set-Cookie, etc. exactly as sent originally. + // Use Header.Add (not c.Set) so multi-value headers such as + // Set-Cookie are appended rather than silently overwritten. + for name, values := range resp.Headers { + for _, v := range values { + c.Response().Header.Add(name, v) + } + } + c.Set(chttp.IdempotencyReplayed, "true") c.Set("Content-Type", resp.ContentType) - return c.Status(resp.StatusCode).SendString(resp.Body) + // Send (not SendString) preserves binary/non-UTF-8 bodies. + return c.Status(resp.StatusCode).Send(resp.Body) } } @@ -268,10 +331,27 @@ func (m *Middleware) saveResult( pipe := client.Pipeline() if len(body) <= m.maxBodyCache { + // Capture response headers for faithful replay. + headers := make(map[string][]string) + + for hdrKey, value := range c.Response().Header.All() { + name := string(hdrKey) + // Skip headers managed by the middleware itself and + // transfer-encoding / content-length which Fiber sets on send. 
+ switch name { + case "Content-Type", "Content-Length", "Transfer-Encoding", + chttp.IdempotencyReplayed: + continue + } + + headers[name] = append(headers[name], string(value)) + } + resp := cachedResponse{ StatusCode: c.Response().StatusCode(), ContentType: string(c.Response().Header.ContentType()), - Body: string(body), + Body: body, + Headers: headers, } if data, marshalErr := json.Marshal(resp); marshalErr == nil { diff --git a/commons/net/http/proxy_validation.go b/commons/net/http/proxy_validation.go index 2cc5f8ed..55ceab41 100644 --- a/commons/net/http/proxy_validation.go +++ b/commons/net/http/proxy_validation.go @@ -2,21 +2,11 @@ package http import ( "net" - "net/netip" "net/url" "strings" -) -var blockedProxyPrefixes = []netip.Prefix{ - netip.MustParsePrefix("0.0.0.0/8"), - netip.MustParsePrefix("100.64.0.0/10"), - netip.MustParsePrefix("192.0.0.0/24"), - netip.MustParsePrefix("192.0.2.0/24"), - netip.MustParsePrefix("198.18.0.0/15"), - netip.MustParsePrefix("198.51.100.0/24"), - netip.MustParsePrefix("203.0.113.0/24"), - netip.MustParsePrefix("240.0.0.0/4"), -} + libSSRF "github.com/LerianStudio/lib-commons/v4/commons/security/ssrf" +) // validateProxyTarget checks a parsed URL against the reverse proxy policy. func validateProxyTarget(targetURL *url.URL, policy ReverseProxyPolicy) error { @@ -79,27 +69,7 @@ func isAllowedHost(host string, allowedHosts []string) bool { } // isUnsafeIP reports whether ip is a loopback, private, or otherwise non-routable address. +// It delegates to the canonical SSRF package for the actual blocked-range check. 
func isUnsafeIP(ip net.IP) bool { - if ip == nil { - return true - } - - if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() { - return true - } - - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return true - } - - addr = addr.Unmap() - - for _, prefix := range blockedProxyPrefixes { - if prefix.Contains(addr) { - return true - } - } - - return false + return libSSRF.IsBlockedIP(ip) } diff --git a/commons/security/ssrf/doc.go b/commons/security/ssrf/doc.go new file mode 100644 index 00000000..463c2347 --- /dev/null +++ b/commons/security/ssrf/doc.go @@ -0,0 +1,52 @@ +// Package ssrf provides server-side request forgery (SSRF) protection for HTTP +// clients that connect to user-controlled URLs. +// +// # Design Decisions +// +// - Fail-closed: every validation error rejects the request. DNS lookup +// failures, unparseable IPs, and empty resolution results all return an +// error rather than silently allowing the request through. +// +// - DNS pinning: [ResolveAndValidate] performs a single DNS lookup, validates +// all resolved IPs, and rewrites the URL to the first safe IP. This +// eliminates the TOCTOU (time-of-check-to-time-of-use) window that exists +// when validation and connection happen in separate steps; a DNS rebinding +// attack cannot change the record between those two operations. +// +// - Single source of truth: the [BlockedPrefixes] list is the canonical CIDR +// blocklist for the entire lib-commons module. Both the webhook deliverer +// and the reverse-proxy helper delegate to this package instead of +// maintaining their own blocklists. +// +// - Modern types: [netip.Prefix] and [netip.Addr] are the canonical types. +// A legacy [net.IP] entry point ([IsBlockedIP]) is provided for callers +// that have not yet migrated, but it delegates to [IsBlockedAddr] after +// conversion. 
+// +// # Usage +// +// Quick IP check (no DNS): +// +// addr, _ := netip.ParseAddr("10.0.0.1") +// if ssrf.IsBlockedAddr(addr) { +// // reject +// } +// +// Full URL validation with DNS pinning: +// +// result, err := ssrf.ResolveAndValidate(ctx, "https://example.com/hook") +// if err != nil { +// // reject — err wraps ssrf.ErrBlocked, ssrf.ErrInvalidURL, or ssrf.ErrDNSFailed +// } +// // Use result.PinnedURL for the actual HTTP request. +// // Set the Host header to result.Authority. +// // Set TLS ServerName to result.SNIHostname. +// +// Custom DNS resolver for tests: +// +// result, err := ssrf.ResolveAndValidate(ctx, rawURL, +// ssrf.WithLookupFunc(func(_ context.Context, _ string) ([]string, error) { +// return []string{"93.184.216.34"}, nil +// }), +// ) +package ssrf diff --git a/commons/security/ssrf/hostname.go b/commons/security/ssrf/hostname.go new file mode 100644 index 00000000..86c4468a --- /dev/null +++ b/commons/security/ssrf/hostname.go @@ -0,0 +1,90 @@ +package ssrf + +import "strings" + +// blockedHostnames contains exact hostnames that must be rejected regardless of +// IP resolution. Case-insensitive comparison is used. +// +//nolint:gochecknoglobals // package-level hostname blocklist is intentional for SSRF protection +var blockedHostnames = map[string]bool{ + "localhost": true, + "metadata.google.internal": true, + "metadata.gcp.internal": true, + // AWS metadata IP as a hostname (also caught by IP check — defense-in-depth). + "169.254.169.254": true, +} + +// blockedSuffixes contains hostname suffixes that indicate internal or +// non-routable DNS names. Any hostname ending with one of these suffixes is +// rejected. +// +// Note on ".internal": this blocks cloud metadata endpoints and internal DNS +// names. In corporate environments that use ".internal" as a legitimate TLD, +// use [WithAllowHostname] to exempt specific hosts. 
+// +//nolint:gochecknoglobals // package-level suffix blocklist is intentional for SSRF protection +var blockedSuffixes = []string{ + ".local", // mDNS / Bonjour (RFC 6762) + ".internal", // cloud metadata, internal DNS + ".cluster.local", // Kubernetes internal DNS +} + +// IsBlockedHostname reports whether hostname matches known dangerous patterns. +// +// Checks performed (case-insensitive): +// - Empty hostname. +// - Exact match against known blocked hostnames (localhost, cloud metadata +// endpoints, AWS metadata IP). +// - Suffix match against dangerous suffixes (.local, .internal, +// .cluster.local). +// +// Note: the ".internal" suffix blocks cloud metadata and internal DNS names but +// may affect legitimate ".internal" domains in corporate environments. Use +// [WithAllowHostname] to exempt specific hostnames when calling +// [ValidateURL] or [ResolveAndValidate]. +func IsBlockedHostname(hostname string) bool { + if hostname == "" { + return true + } + + lower := strings.ToLower(hostname) + + if blockedHostnames[lower] { + return true + } + + for _, suffix := range blockedSuffixes { + if strings.HasSuffix(lower, suffix) { + return true + } + } + + return false +} + +// isBlockedHostnameWithConfig performs the same check as [IsBlockedHostname] +// but respects the allowedHostnames override from functional options. +func isBlockedHostnameWithConfig(hostname string, cfg *config) bool { + if hostname == "" { + return true + } + + lower := strings.ToLower(hostname) + + // Check allow-list override first. 
+ if cfg != nil && cfg.allowedHostnames[lower] { + return false + } + + if blockedHostnames[lower] { + return true + } + + for _, suffix := range blockedSuffixes { + if strings.HasSuffix(lower, suffix) { + return true + } + } + + return false +} diff --git a/commons/security/ssrf/options.go b/commons/security/ssrf/options.go new file mode 100644 index 00000000..d509d4c8 --- /dev/null +++ b/commons/security/ssrf/options.go @@ -0,0 +1,84 @@ +package ssrf + +import ( + "context" + "strings" +) + +// LookupFunc is the signature for a DNS resolver function. It mirrors +// [net.Resolver.LookupHost] and is used by [WithLookupFunc] to inject a custom +// resolver for testing or special environments. +type LookupFunc func(ctx context.Context, host string) ([]string, error) + +// Option configures the behaviour of [ValidateURL] and [ResolveAndValidate]. +type Option func(*config) + +// config holds the resolved configuration built from functional [Option] values. +type config struct { + // httpsOnly rejects URLs whose scheme is not "https". + httpsOnly bool + + // allowPrivate bypasses IP blocking entirely. Intended for local + // development and testing only — never enable in production. + allowPrivate bool + + // lookupFunc overrides the default DNS resolver. When nil, the default + // net.DefaultResolver.LookupHost is used. + lookupFunc LookupFunc + + // allowedHostnames exempts specific hostnames from [IsBlockedHostname] + // checks. Keys are stored lower-cased. + allowedHostnames map[string]bool +} + +// buildConfig applies all options to a default config. +func buildConfig(opts []Option) *config { + cfg := &config{} + + for _, opt := range opts { + opt(cfg) + } + + return cfg +} + +// WithHTTPSOnly rejects any URL whose scheme is not "https". Use this when the +// target must always be contacted over TLS. 
+func WithHTTPSOnly() Option { + return func(c *config) { + c.httpsOnly = true + } +} + +// WithAllowPrivateNetwork bypasses IP blocking entirely, allowing connections +// to loopback, private, and link-local addresses. This is intended exclusively +// for local development and testing — never enable it in production. +func WithAllowPrivateNetwork() Option { + return func(c *config) { + c.allowPrivate = true + } +} + +// WithLookupFunc sets a custom DNS resolver. This is primarily useful in tests +// to avoid real DNS lookups and to exercise specific resolution scenarios (e.g. +// all IPs blocked, mixed safe/blocked IPs). +func WithLookupFunc(fn LookupFunc) Option { + return func(c *config) { + c.lookupFunc = fn + } +} + +// WithAllowHostname exempts a specific hostname from [IsBlockedHostname] +// checks. The comparison is case-insensitive. This is useful in corporate +// environments where a legitimate service runs on a ".internal" domain. +// +// Multiple calls accumulate — each call adds one hostname to the allow-list. +func WithAllowHostname(hostname string) Option { + return func(c *config) { + if c.allowedHostnames == nil { + c.allowedHostnames = make(map[string]bool) + } + + c.allowedHostnames[strings.ToLower(hostname)] = true + } +} diff --git a/commons/security/ssrf/ssrf.go b/commons/security/ssrf/ssrf.go new file mode 100644 index 00000000..47b35293 --- /dev/null +++ b/commons/security/ssrf/ssrf.go @@ -0,0 +1,110 @@ +package ssrf + +import ( + "errors" + "net" + "net/netip" +) + +// Sentinel errors returned by validation functions. Callers should use +// [errors.Is] to check the category because concrete errors wrap these +// sentinels with additional context via [fmt.Errorf] and %w. +var ( + // ErrBlocked is returned when a URL or IP is rejected by SSRF protection. + ErrBlocked = errors.New("ssrf: blocked") + + // ErrInvalidURL is returned when a URL cannot be parsed or is structurally + // invalid (empty hostname, missing scheme, etc.). 
+ ErrInvalidURL = errors.New("ssrf: invalid URL") + + // ErrDNSFailed is returned when DNS resolution fails for a hostname. + ErrDNSFailed = errors.New("ssrf: DNS resolution failed") +) + +// blockedPrefixes is the canonical CIDR blocklist. It covers RFC-defined +// special-purpose ranges that are not caught by the standard library predicates +// (IsLoopback, IsPrivate, IsLinkLocalUnicast, etc.). +// +// Each entry is intentionally a netip.Prefix literal so that typos are caught +// at init time rather than silently ignored at request time. +// +//nolint:gochecknoglobals // package-level CIDR blocklist is intentional for SSRF protection +var blockedPrefixes = []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/8"), // "this network" (RFC 1122 S3.2.1.3) + netip.MustParsePrefix("100.64.0.0/10"), // CGNAT (RFC 6598) + netip.MustParsePrefix("192.0.0.0/24"), // IETF protocol assignments (RFC 6890) + netip.MustParsePrefix("192.0.2.0/24"), // TEST-NET-1 documentation (RFC 5737) + netip.MustParsePrefix("198.18.0.0/15"), // benchmarking (RFC 2544) + netip.MustParsePrefix("198.51.100.0/24"), // TEST-NET-2 documentation (RFC 5737) + netip.MustParsePrefix("203.0.113.0/24"), // TEST-NET-3 documentation (RFC 5737) + netip.MustParsePrefix("240.0.0.0/4"), // reserved / future use (RFC 1112) +} + +// BlockedPrefixes returns a copy of the canonical CIDR blocklist. The returned +// slice is safe to modify without affecting the package state. +func BlockedPrefixes() []netip.Prefix { + out := make([]netip.Prefix, len(blockedPrefixes)) + copy(out, blockedPrefixes) + + return out +} + +// IsBlockedAddr reports whether addr falls in a private, loopback, link-local, +// multicast, unspecified, or other reserved range that must not be contacted by +// outbound HTTP requests. +// +// Check order: +// 1. Unmap IPv4-mapped IPv6 addresses (e.g. ::ffff:127.0.0.1 becomes 127.0.0.1). +// 2. 
Standard library predicates: IsLoopback, IsPrivate, IsLinkLocalUnicast, +// IsLinkLocalMulticast, IsMulticast, IsUnspecified. +// 3. Custom blocklist ([BlockedPrefixes]). +func IsBlockedAddr(addr netip.Addr) bool { + // Step 0: the zero-value netip.Addr{} is not a real IP address. Treating + // it as safe would violate the fail-closed principle — block it. + if !addr.IsValid() { + return true + } + + // Step 1: unmap IPv4-mapped IPv6 so that ::ffff:10.0.0.1 is treated + // identically to 10.0.0.1. + addr = addr.Unmap() + + // Step 2: standard library predicates. + if addr.IsLoopback() || + addr.IsPrivate() || + addr.IsLinkLocalUnicast() || + addr.IsLinkLocalMulticast() || + addr.IsMulticast() || + addr.IsUnspecified() { + return true + } + + // Step 3: custom CIDR blocklist for ranges not covered by the predicates + // above (CGNAT, TEST-NETs, benchmarking, reserved, etc.). + for _, prefix := range blockedPrefixes { + if prefix.Contains(addr) { + return true + } + } + + return false +} + +// IsBlockedIP reports whether ip falls in a blocked range. This is a +// convenience wrapper for callers that still use the legacy [net.IP] type; it +// converts to [netip.Addr] and delegates to [IsBlockedAddr]. +// +// A nil ip is considered blocked (fail-closed). +func IsBlockedIP(ip net.IP) bool { + if ip == nil { + return true + } + + addr, ok := netip.AddrFromSlice(ip) + if !ok { + // Unparseable IP — fail-closed. 
+ return true + } + + return IsBlockedAddr(addr) +} diff --git a/commons/security/ssrf/ssrf_test.go b/commons/security/ssrf/ssrf_test.go new file mode 100644 index 00000000..9b5116f3 --- /dev/null +++ b/commons/security/ssrf/ssrf_test.go @@ -0,0 +1,898 @@ +//go:build unit + +package ssrf + +import ( + "context" + "errors" + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// BlockedPrefixes +// --------------------------------------------------------------------------- + +// expectedPrefixCount lives in the test file so the production code does not +// export or carry a constant that is only meaningful for test assertions. +// Update this value when adding new CIDR ranges to blockedPrefixes. +const expectedPrefixCount = 8 + +func TestBlockedPrefixes_ReturnsExpectedCount(t *testing.T) { + t.Parallel() + + prefixes := BlockedPrefixes() + assert.Len(t, prefixes, expectedPrefixCount, + "BlockedPrefixes() should return exactly %d prefixes", expectedPrefixCount) +} + +func TestBlockedPrefixes_ReturnsCopy(t *testing.T) { + t.Parallel() + + a := BlockedPrefixes() + b := BlockedPrefixes() + + // Mutate the first slice and verify the second is unaffected. 
+ a[0] = netip.MustParsePrefix("255.255.255.255/32") + assert.NotEqual(t, a[0], b[0], + "BlockedPrefixes() must return independent copies") +} + +func TestBlockedPrefixes_ContainsExpectedRanges(t *testing.T) { + t.Parallel() + + expected := []string{ + "0.0.0.0/8", + "100.64.0.0/10", + "192.0.0.0/24", + "192.0.2.0/24", + "198.18.0.0/15", + "198.51.100.0/24", + "203.0.113.0/24", + "240.0.0.0/4", + } + + prefixes := BlockedPrefixes() + strs := make([]string, len(prefixes)) + + for i, p := range prefixes { + strs[i] = p.String() + } + + for _, exp := range expected { + assert.Contains(t, strs, exp, + "BlockedPrefixes() should contain %s", exp) + } +} + +// --------------------------------------------------------------------------- +// IsBlockedAddr — netip.Addr interface +// --------------------------------------------------------------------------- + +func TestIsBlockedAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + addr string + blocked bool + }{ + // --- Loopback --- + {name: "IPv4 loopback 127.0.0.1", addr: "127.0.0.1", blocked: true}, + {name: "IPv4 loopback 127.0.0.2", addr: "127.0.0.2", blocked: true}, + {name: "IPv4 loopback 127.255.255.255", addr: "127.255.255.255", blocked: true}, + {name: "IPv6 loopback ::1", addr: "::1", blocked: true}, + + // --- RFC 1918 private ranges --- + {name: "10.0.0.1 (Class A)", addr: "10.0.0.1", blocked: true}, + {name: "10.255.255.255 (Class A top)", addr: "10.255.255.255", blocked: true}, + {name: "172.16.0.1 (Class B)", addr: "172.16.0.1", blocked: true}, + {name: "172.31.255.255 (Class B top)", addr: "172.31.255.255", blocked: true}, + {name: "192.168.0.1 (Class C)", addr: "192.168.0.1", blocked: true}, + {name: "192.168.255.255 (Class C top)", addr: "192.168.255.255", blocked: true}, + + // --- Link-local --- + {name: "IPv4 link-local 169.254.1.1", addr: "169.254.1.1", blocked: true}, + {name: "IPv4 link-local AWS metadata", addr: "169.254.169.254", blocked: true}, + {name: "IPv6 link-local 
unicast fe80::1", addr: "fe80::1", blocked: true}, + + // --- Link-local multicast --- + {name: "IPv4 multicast 224.0.0.1", addr: "224.0.0.1", blocked: true}, + {name: "IPv4 multicast 239.255.255.255", addr: "239.255.255.255", blocked: true}, + + // --- Unspecified --- + {name: "IPv4 unspecified 0.0.0.0", addr: "0.0.0.0", blocked: true}, + {name: "IPv6 unspecified ::", addr: "::", blocked: true}, + + // --- CGNAT (RFC 6598): 100.64.0.0/10 --- + {name: "CGNAT low 100.64.0.1", addr: "100.64.0.1", blocked: true}, + {name: "CGNAT mid 100.100.100.100", addr: "100.100.100.100", blocked: true}, + {name: "CGNAT high 100.127.255.254", addr: "100.127.255.254", blocked: true}, + + // --- "this network" (RFC 1122): 0.0.0.0/8 --- + {name: "0.0.0.1 this-network", addr: "0.0.0.1", blocked: true}, + {name: "0.255.255.255 this-network top", addr: "0.255.255.255", blocked: true}, + + // --- IETF protocol assignments (RFC 6890): 192.0.0.0/24 --- + {name: "192.0.0.1 IETF", addr: "192.0.0.1", blocked: true}, + {name: "192.0.0.254 IETF top", addr: "192.0.0.254", blocked: true}, + + // --- TEST-NET-1 (RFC 5737): 192.0.2.0/24 --- + {name: "192.0.2.1 TEST-NET-1", addr: "192.0.2.1", blocked: true}, + + // --- Benchmarking (RFC 2544): 198.18.0.0/15 --- + {name: "198.18.0.1 benchmarking", addr: "198.18.0.1", blocked: true}, + {name: "198.19.255.255 benchmarking top", addr: "198.19.255.255", blocked: true}, + + // --- TEST-NET-2 (RFC 5737): 198.51.100.0/24 --- + {name: "198.51.100.1 TEST-NET-2", addr: "198.51.100.1", blocked: true}, + {name: "198.51.100.10 TEST-NET-2", addr: "198.51.100.10", blocked: true}, + + // --- TEST-NET-3 (RFC 5737): 203.0.113.0/24 --- + {name: "203.0.113.1 TEST-NET-3", addr: "203.0.113.1", blocked: true}, + {name: "203.0.113.10 TEST-NET-3", addr: "203.0.113.10", blocked: true}, + + // --- Reserved (RFC 1112): 240.0.0.0/4 --- + {name: "240.0.0.1 reserved", addr: "240.0.0.1", blocked: true}, + {name: "255.255.255.254 reserved top", addr: "255.255.255.254", blocked: 
true}, + + // --- IPv4-mapped IPv6 (must unmap and check) --- + {name: "IPv4-mapped loopback ::ffff:127.0.0.1", addr: "::ffff:127.0.0.1", blocked: true}, + {name: "IPv4-mapped private ::ffff:10.0.0.1", addr: "::ffff:10.0.0.1", blocked: true}, + {name: "IPv4-mapped CGNAT ::ffff:100.64.0.1", addr: "::ffff:100.64.0.1", blocked: true}, + {name: "IPv4-mapped TEST-NET-2 ::ffff:198.51.100.10", addr: "::ffff:198.51.100.10", blocked: true}, + + // --- Public IPs — must NOT be blocked --- + {name: "Google DNS 8.8.8.8", addr: "8.8.8.8", blocked: false}, + {name: "Cloudflare 1.1.1.1", addr: "1.1.1.1", blocked: false}, + {name: "Public IPv6 2001:4860:4860::8888", addr: "2001:4860:4860::8888", blocked: false}, + {name: "93.184.216.34 (example.com)", addr: "93.184.216.34", blocked: false}, + + // --- Boundary cases — just outside blocked ranges --- + {name: "100.63.255.255 below CGNAT", addr: "100.63.255.255", blocked: false}, + {name: "100.128.0.1 above CGNAT", addr: "100.128.0.1", blocked: false}, + {name: "172.15.255.255 below Class B", addr: "172.15.255.255", blocked: false}, + {name: "172.32.0.1 above Class B", addr: "172.32.0.1", blocked: false}, + {name: "1.0.0.1 above this-network", addr: "1.0.0.1", blocked: false}, + {name: "198.17.255.255 below benchmarking", addr: "198.17.255.255", blocked: false}, + {name: "198.20.0.1 above benchmarking", addr: "198.20.0.1", blocked: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + addr := netip.MustParseAddr(tt.addr) + assert.Equal(t, tt.blocked, IsBlockedAddr(addr), + "IsBlockedAddr(%s) = %v, want %v", tt.addr, IsBlockedAddr(addr), tt.blocked) + }) + } +} + +func TestIsBlockedAddr_ZeroValue(t *testing.T) { + t.Parallel() + + // The zero-value netip.Addr{} is invalid and must be blocked (fail-closed). 
+ var zero netip.Addr + assert.True(t, IsBlockedAddr(zero), "IsBlockedAddr(netip.Addr{}) must return true (fail-closed)") +} + +// --------------------------------------------------------------------------- +// IsBlockedIP — legacy net.IP interface +// --------------------------------------------------------------------------- + +func TestIsBlockedIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + blocked bool + }{ + {name: "IPv4 loopback", ip: "127.0.0.1", blocked: true}, + {name: "IPv6 loopback", ip: "::1", blocked: true}, + {name: "RFC 1918 10.x", ip: "10.0.0.1", blocked: true}, + {name: "RFC 1918 172.16.x", ip: "172.16.0.1", blocked: true}, + {name: "RFC 1918 192.168.x", ip: "192.168.0.1", blocked: true}, + {name: "Link-local", ip: "169.254.1.1", blocked: true}, + {name: "CGNAT", ip: "100.64.0.1", blocked: true}, + {name: "TEST-NET-1", ip: "192.0.2.1", blocked: true}, + {name: "Multicast", ip: "224.0.0.1", blocked: true}, + {name: "Google DNS", ip: "8.8.8.8", blocked: false}, + {name: "Cloudflare", ip: "1.1.1.1", blocked: false}, + {name: "Public IPv6", ip: "2001:4860:4860::8888", blocked: false}, + // IPv4-mapped IPv6 + {name: "IPv4-mapped loopback", ip: "::ffff:127.0.0.1", blocked: true}, + {name: "IPv4-mapped private", ip: "::ffff:10.0.0.1", blocked: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ip := net.ParseIP(tt.ip) + require.NotNil(t, ip, "failed to parse IP: %s", tt.ip) + assert.Equal(t, tt.blocked, IsBlockedIP(ip), + "IsBlockedIP(%s) = %v, want %v", tt.ip, IsBlockedIP(ip), tt.blocked) + }) + } +} + +func TestIsBlockedIP_NilIP(t *testing.T) { + t.Parallel() + + assert.True(t, IsBlockedIP(nil), "IsBlockedIP(nil) must return true (fail-closed)") +} + +// --------------------------------------------------------------------------- +// IsBlockedHostname +// --------------------------------------------------------------------------- + +func TestIsBlockedHostname(t 
*testing.T) { + t.Parallel() + + tests := []struct { + name string + host string + blocked bool + }{ + // Exact blocked hostnames + {name: "localhost", host: "localhost", blocked: true}, + {name: "LOCALHOST uppercase", host: "LOCALHOST", blocked: true}, + {name: "Localhost mixed", host: "Localhost", blocked: true}, + {name: "metadata.google.internal", host: "metadata.google.internal", blocked: true}, + {name: "metadata.gcp.internal", host: "metadata.gcp.internal", blocked: true}, + {name: "AWS metadata IP", host: "169.254.169.254", blocked: true}, + + // Dangerous suffixes + {name: ".local suffix", host: "myhost.local", blocked: true}, + {name: ".internal suffix", host: "service.internal", blocked: true}, + {name: ".cluster.local suffix", host: "api.default.svc.cluster.local", blocked: true}, + {name: ".INTERNAL uppercase", host: "service.INTERNAL", blocked: true}, + + // Empty hostname + {name: "empty string", host: "", blocked: true}, + + // Legitimate hostnames — must NOT be blocked + {name: "example.com", host: "example.com", blocked: false}, + {name: "api.example.com", host: "api.example.com", blocked: false}, + {name: "hooks.stripe.com", host: "hooks.stripe.com", blocked: false}, + {name: "8.8.8.8 public", host: "8.8.8.8", blocked: false}, + {name: "2001:4860:4860::8888", host: "2001:4860:4860::8888", blocked: false}, + + // Tricky near-misses that must NOT be blocked + {name: "localhostx not blocked", host: "localhostx", blocked: false}, + {name: "mylocal not blocked", host: "mylocal", blocked: false}, + {name: "internal.example.com not blocked", host: "internal.example.com", blocked: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.blocked, IsBlockedHostname(tt.host), + "IsBlockedHostname(%q) = %v, want %v", tt.host, IsBlockedHostname(tt.host), tt.blocked) + }) + } +} + +// --------------------------------------------------------------------------- +// ValidateURL +// 
--------------------------------------------------------------------------- + +func TestValidateURL_ValidPublicURLs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "https example.com", url: "https://example.com/hook"}, + {name: "http example.com", url: "http://example.com/hook"}, + {name: "https with port", url: "https://example.com:8443/hook"}, + {name: "https with path and query", url: "https://example.com/a/b?c=d"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + assert.NoError(t, err) + }) + } +} + +func TestValidateURL_BlockedSchemes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "gopher", url: "gopher://evil.com"}, + {name: "file", url: "file:///etc/passwd"}, + {name: "ftp", url: "ftp://example.com/file"}, + {name: "javascript", url: "javascript:alert(1)"}, + {name: "data", url: "data:text/html,

hi

"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestValidateURL_BlockedHostnames(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "localhost", url: "http://localhost/hook"}, + {name: "metadata.google.internal", url: "https://metadata.google.internal/computeMetadata"}, + {name: ".local suffix", url: "http://printer.local/status"}, + {name: "k8s internal", url: "http://api.default.svc.cluster.local/health"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestValidateURL_BlockedIPLiteral(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "loopback", url: "http://127.0.0.1/hook"}, + {name: "private 10.x", url: "https://10.0.0.1/hook"}, + {name: "private 192.168.x", url: "https://192.168.1.1/hook"}, + {name: "CGNAT", url: "http://100.64.0.1/hook"}, + {name: "IPv6 loopback", url: "http://[::1]/hook"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestValidateURL_EmptyHostname(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestValidateURL_MalformedURL(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "://missing-scheme") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestValidateURL_HTTPSOnly(t *testing.T) { + t.Parallel() + + t.Run("https allowed", func(t *testing.T) { + t.Parallel() + + err := 
ValidateURL(context.Background(), "https://example.com/hook", WithHTTPSOnly()) + assert.NoError(t, err) + }) + + t.Run("http rejected", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://example.com/hook", WithHTTPSOnly()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + assert.Contains(t, err.Error(), "HTTPS only") + }) +} + +func TestValidateURL_AllowPrivateNetwork(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "loopback IP", url: "http://127.0.0.1/hook"}, + {name: "private IP", url: "http://10.0.0.1/hook"}, + {name: "CGNAT IP", url: "http://100.64.0.1/hook"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url, WithAllowPrivateNetwork()) + // AllowPrivateNetwork bypasses IP blocking but hostname blocking + // is still evaluated. Since these are IP literals (not "localhost"), + // they should pass. + assert.NoError(t, err) + }) + } +} + +func TestValidateURL_AllowHostname(t *testing.T) { + t.Parallel() + + t.Run("exempted hostname passes", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://service.internal/hook", + WithAllowHostname("service.internal")) + assert.NoError(t, err) + }) + + t.Run("non-exempted hostname still blocked", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://other.internal/hook", + WithAllowHostname("service.internal")) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + + t.Run("case insensitive", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://Service.Internal/hook", + WithAllowHostname("service.internal")) + assert.NoError(t, err) + }) +} + +// --------------------------------------------------------------------------- +// ResolveAndValidate +// --------------------------------------------------------------------------- + 
+// fakeLookup returns a LookupFunc that yields the given IPs. +func fakeLookup(ips ...string) LookupFunc { + return func(_ context.Context, _ string) ([]string, error) { + return ips, nil + } +} + +// failLookup returns a LookupFunc that always fails. +func failLookup() LookupFunc { + return func(_ context.Context, _ string) ([]string, error) { + return nil, errors.New("simulated DNS failure") + } +} + +func TestResolveAndValidate_PublicIP(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com:8443/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "https://93.184.216.34:8443/hook", result.PinnedURL) + assert.Equal(t, "example.com:8443", result.Authority) + assert.Equal(t, "example.com", result.SNIHostname) +} + +func TestResolveAndValidate_PublicIPNoPort(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "https://93.184.216.34/hook", result.PinnedURL) + assert.Equal(t, "example.com", result.Authority) + assert.Equal(t, "example.com", result.SNIHostname) +} + +func TestResolveAndValidate_IPv6Public(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("2001:4860:4860::8888")), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "https://[2001:4860:4860::8888]/hook", result.PinnedURL) + assert.Equal(t, "example.com", result.Authority) + assert.Equal(t, "example.com", result.SNIHostname) +} + +func TestResolveAndValidate_BlockedIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + }{ + {name: "loopback", ip: "127.0.0.1"}, + {name: "private 10.x", ip: "10.0.0.1"}, + {name: "private 
192.168.x", ip: "192.168.1.1"}, + {name: "CGNAT", ip: "100.64.0.1"}, + {name: "link-local", ip: "169.254.169.254"}, + {name: "TEST-NET-1", ip: "192.0.2.1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup(tt.ip)), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestResolveAndValidate_DNSFailure(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(failLookup()), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrDNSFailed) +} + +func TestResolveAndValidate_DNSEmptyResult(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup()), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrDNSFailed) +} + +func TestResolveAndValidate_BlockedScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "gopher", url: "gopher://evil.com"}, + {name: "file", url: "file:///etc/passwd"}, + {name: "ftp", url: "ftp://example.com/file"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), tt.url, + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestResolveAndValidate_BlockedHostname(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "localhost", url: "http://localhost/hook"}, + {name: "metadata.google.internal", url: "https://metadata.google.internal/computeMetadata"}, + {name: ".local suffix", url: "http://printer.local/status"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := 
ResolveAndValidate(context.Background(), tt.url, + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestResolveAndValidate_EmptyHostname(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), "http://", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestResolveAndValidate_MalformedURL(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), "://no-scheme", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestResolveAndValidate_HTTPSOnly(t *testing.T) { + t.Parallel() + + t.Run("https passes", func(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + WithHTTPSOnly(), + ) + + require.NoError(t, err) + assert.Contains(t, result.PinnedURL, "https://") + }) + + t.Run("http rejected", func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "http://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + WithHTTPSOnly(), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) +} + +func TestResolveAndValidate_AllowPrivateNetwork(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "http://example.com/hook", + WithLookupFunc(fakeLookup("10.0.0.1")), + WithAllowPrivateNetwork(), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "http://10.0.0.1/hook", result.PinnedURL) +} + +func TestResolveAndValidate_AllowHostname(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "http://service.internal/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + WithAllowHostname("service.internal"), + ) + + 
require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "http://93.184.216.34/hook", result.PinnedURL) + assert.Equal(t, "service.internal", result.SNIHostname) +} + +func TestResolveAndValidate_MultipleIPs_FirstSafe(t *testing.T) { + t.Parallel() + + // First IP is safe — should be selected for pinning. + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34", "1.1.1.1")), + ) + + require.NoError(t, err) + assert.Equal(t, "https://93.184.216.34/hook", result.PinnedURL) +} + +func TestResolveAndValidate_MultipleIPs_BlockedAmongSafe(t *testing.T) { + t.Parallel() + + // First IP is blocked — fail-closed: reject the entire request. + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("10.0.0.1", "93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) +} + +func TestResolveAndValidate_AllIPsBlocked(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("127.0.0.1", "10.0.0.1", "192.168.1.1")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) +} + +func TestResolveAndValidate_UnparseableIPSkipped(t *testing.T) { + t.Parallel() + + // DNS returns an unparseable entry followed by a valid one. 
+ result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("not-an-ip", "93.184.216.34")), + ) + + require.NoError(t, err) + assert.Equal(t, "https://93.184.216.34/hook", result.PinnedURL) +} + +func TestResolveAndValidate_AllUnparseable(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("not-an-ip", "also-not-an-ip")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +// --------------------------------------------------------------------------- +// Options composition +// --------------------------------------------------------------------------- + +func TestOptions_MultipleAllowHostname(t *testing.T) { + t.Parallel() + + opts := []Option{ + WithAllowHostname("a.internal"), + WithAllowHostname("b.internal"), + WithLookupFunc(fakeLookup("93.184.216.34")), + } + + t.Run("first allowed", func(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), "http://a.internal/hook", opts...) + require.NoError(t, err) + assert.Equal(t, "a.internal", result.SNIHostname) + }) + + t.Run("second allowed", func(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), "http://b.internal/hook", opts...) + require.NoError(t, err) + assert.Equal(t, "b.internal", result.SNIHostname) + }) + + t.Run("third still blocked", func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), "http://c.internal/hook", opts...) 
+ require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) +} + +// --------------------------------------------------------------------------- +// validateScheme (internal, tested via ValidateURL/ResolveAndValidate above, +// but explicit coverage for edge cases) +// --------------------------------------------------------------------------- + +func TestValidateScheme_CaseInsensitive(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + scheme string + }{ + {name: "HTTP uppercase", scheme: "HTTP"}, + {name: "HTTPS uppercase", scheme: "HTTPS"}, + {name: "Mixed Http", scheme: "Http"}, + {name: "Mixed hTtPs", scheme: "hTtPs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := validateScheme(tt.scheme, &config{}) + assert.NoError(t, err, "scheme %q must be allowed", tt.scheme) + }) + } +} + +func TestValidateScheme_EmptyScheme(t *testing.T) { + t.Parallel() + + err := validateScheme("", &config{}) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) +} + +// --------------------------------------------------------------------------- +// isBlockedHostnameWithConfig +// --------------------------------------------------------------------------- + +func TestIsBlockedHostnameWithConfig_NilConfig(t *testing.T) { + t.Parallel() + + // nil config should behave like IsBlockedHostname. 
+ assert.True(t, isBlockedHostnameWithConfig("localhost", nil)) + assert.False(t, isBlockedHostnameWithConfig("example.com", nil)) +} + +func TestIsBlockedHostnameWithConfig_AllowOverride(t *testing.T) { + t.Parallel() + + cfg := &config{ + allowedHostnames: map[string]bool{ + "special.internal": true, + }, + } + + assert.False(t, isBlockedHostnameWithConfig("special.internal", cfg), + "allowed hostname should not be blocked") + assert.True(t, isBlockedHostnameWithConfig("other.internal", cfg), + "non-allowed .internal hostname should still be blocked") +} + +// --------------------------------------------------------------------------- +// Error sentinel identity +// --------------------------------------------------------------------------- + +func TestSentinelErrors_AreDistinct(t *testing.T) { + t.Parallel() + + assert.NotErrorIs(t, ErrBlocked, ErrInvalidURL) + assert.NotErrorIs(t, ErrBlocked, ErrDNSFailed) + assert.NotErrorIs(t, ErrInvalidURL, ErrDNSFailed) +} diff --git a/commons/security/ssrf/validate.go b/commons/security/ssrf/validate.go new file mode 100644 index 00000000..5ac1e888 --- /dev/null +++ b/commons/security/ssrf/validate.go @@ -0,0 +1,200 @@ +package ssrf + +import ( + "context" + "fmt" + "net" + "net/netip" + "net/url" + "strings" +) + +// ResolveResult holds the output of [ResolveAndValidate]. Callers should use +// [PinnedURL] for the actual HTTP request, set the Host header to [Authority], +// and configure TLS ServerName to [SNIHostname]. +type ResolveResult struct { + // PinnedURL is the original URL with the hostname replaced by the first + // safe resolved IP address. Using this URL for the HTTP request eliminates + // the TOCTOU window between DNS validation and connection. + PinnedURL string + + // Authority is the original host:port value (url.URL.Host) before DNS + // pinning. It preserves explicit non-default ports and should be used as + // the HTTP Host header value so the target server routes correctly. 
+ Authority string + + // SNIHostname is the bare hostname without port (url.URL.Hostname()). It + // should be used for TLS SNI / certificate verification. + SNIHostname string +} + +// ValidateURL checks a URL for SSRF safety without performing DNS resolution. +// It validates the scheme, hostname blocking, and any IP literal in the +// hostname. +// +// Use [ResolveAndValidate] when DNS pinning is needed (i.e. when you intend to +// actually connect to the URL). ValidateURL is suitable for pre-flight +// validation where DNS resolution is deferred or not desired. +// +// Errors wrap [ErrBlocked] or [ErrInvalidURL] for programmatic inspection via +// [errors.Is]. +func ValidateURL(ctx context.Context, rawURL string, opts ...Option) error { + _ = ctx // TODO(ssrf): use ctx for tracing/metrics integration when telemetry is wired in + + cfg := buildConfig(opts) + + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("%w: %w", ErrInvalidURL, err) + } + + if err := validateScheme(u.Scheme, cfg); err != nil { + return err + } + + hostname := u.Hostname() + if hostname == "" { + return fmt.Errorf("%w: empty hostname", ErrInvalidURL) + } + + // Check hostname-level blocking (localhost, metadata endpoints, etc.). + if isBlockedHostnameWithConfig(hostname, cfg) { + return fmt.Errorf("%w: hostname %q is blocked", ErrBlocked, hostname) + } + + // If the hostname is an IP literal, validate it against the CIDR blocklist. + if !cfg.allowPrivate { + if addr, err := netip.ParseAddr(hostname); err == nil { + if IsBlockedAddr(addr) { + return fmt.Errorf("%w: IP %s is in a blocked range", ErrBlocked, hostname) + } + } + } + + return nil +} + +// ResolveAndValidate performs DNS resolution, validates all resolved IPs +// against the SSRF blocklist, and returns a pinned URL. This eliminates the +// TOCTOU window between "validate" and "connect" by combining both into a +// single DNS lookup. +// +// Flow: +// 1. 
Parse URL, validate scheme (http/https only, or https-only with [WithHTTPSOnly]). +// 2. Check hostname blocking ([IsBlockedHostname]). +// 3. Resolve DNS (single lookup via [net.DefaultResolver] or [WithLookupFunc]). +// 4. Validate resolved IPs — reject if ANY IP is blocked ([IsBlockedAddr]). +// 5. Pin URL to first safe IP, return [ResolveResult]. +// +// DNS lookup failures are fail-closed: if the hostname cannot be resolved the +// URL is rejected. When every resolved IP is blocked the URL is also rejected. +// +// Errors wrap [ErrBlocked], [ErrInvalidURL], or [ErrDNSFailed] for +// programmatic inspection via [errors.Is]. +func ResolveAndValidate(ctx context.Context, rawURL string, opts ...Option) (*ResolveResult, error) { + cfg := buildConfig(opts) + + u, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrInvalidURL, err) + } + + if err := validateScheme(u.Scheme, cfg); err != nil { + return nil, err + } + + hostname := u.Hostname() + if hostname == "" { + return nil, fmt.Errorf("%w: empty hostname", ErrInvalidURL) + } + + // Hostname-level blocking (localhost, metadata endpoints, dangerous suffixes). + if isBlockedHostnameWithConfig(hostname, cfg) { + return nil, fmt.Errorf("%w: hostname %q is blocked", ErrBlocked, hostname) + } + + // DNS resolution — use custom resolver if provided, otherwise default. + ips, dnsErr := lookupHost(ctx, hostname, cfg) + if dnsErr != nil { + return nil, fmt.Errorf("%w: lookup failed for %s: %w", ErrDNSFailed, hostname, dnsErr) + } + + if len(ips) == 0 { + return nil, fmt.Errorf("%w: no addresses returned for %s", ErrDNSFailed, hostname) + } + + // Validate every resolved IP and find the first safe one. + var firstSafeIP string + + for _, ipStr := range ips { + addr, parseErr := netip.ParseAddr(ipStr) + if parseErr != nil { + // Skip unparseable entries; if none survive we fail below. 
+ continue + } + + if !cfg.allowPrivate && IsBlockedAddr(addr) { + return nil, fmt.Errorf("%w: resolved IP %s is in a blocked range", ErrBlocked, ipStr) + } + + if firstSafeIP == "" { + firstSafeIP = ipStr + } + } + + if firstSafeIP == "" { + return nil, fmt.Errorf("%w: no valid IPs resolved for %s", ErrInvalidURL, hostname) + } + + // Preserve the original authority before rewriting the host. + authority := u.Host + + // Pin to first safe IP — prevents DNS rebinding across retries. + port := u.Port() + + switch { + case port != "": + u.Host = net.JoinHostPort(firstSafeIP, port) + case strings.Contains(firstSafeIP, ":"): + // Bare IPv6 literal must be bracket-wrapped for url.URL.Host. + u.Host = "[" + firstSafeIP + "]" + default: + u.Host = firstSafeIP + } + + return &ResolveResult{ + PinnedURL: u.String(), + Authority: authority, + SNIHostname: hostname, + }, nil +} + +// validateScheme checks that the URL scheme is allowed. By default only "http" +// and "https" are permitted. With [WithHTTPSOnly], only "https" is allowed. +func validateScheme(scheme string, cfg *config) error { + s := strings.ToLower(scheme) + + if cfg.httpsOnly { + if s != "https" { + return fmt.Errorf("%w: scheme %q not allowed (HTTPS only)", ErrBlocked, scheme) + } + + return nil + } + + if s != "http" && s != "https" { + return fmt.Errorf("%w: scheme %q not allowed", ErrBlocked, scheme) + } + + return nil +} + +// lookupHost resolves hostname via the configured lookup function or the +// default resolver. 
+func lookupHost(ctx context.Context, hostname string, cfg *config) ([]string, error) { + if cfg.lookupFunc != nil { + return cfg.lookupFunc(ctx, hostname) + } + + return net.DefaultResolver.LookupHost(ctx, hostname) +} diff --git a/commons/systemplane/bootstrap/backend_test.go b/commons/systemplane/bootstrap/backend_test.go index 05b907a3..8b3b2430 100644 --- a/commons/systemplane/bootstrap/backend_test.go +++ b/commons/systemplane/bootstrap/backend_test.go @@ -408,7 +408,7 @@ func TestResetBackendFactories_ClearsRegistrations(t *testing.T) { RecordInitError(fmt.Errorf("simulated init error")) factories, initErrors := backendRegistry.snapshot() - assert.Len(t, factories, 1) + assert.Contains(t, factories, domain.BackendPostgres, "expected the registered factory to be present") assert.NotEmpty(t, initErrors) // Reset and verify. diff --git a/commons/systemplane/catalog/validate.go b/commons/systemplane/catalog/validate.go index 9caeb127..8aca53dc 100644 --- a/commons/systemplane/catalog/validate.go +++ b/commons/systemplane/catalog/validate.go @@ -32,6 +32,17 @@ func (m Mismatch) String() string { m.CatalogKey, m.Field, m.CatalogValue, m.ProductValue) } +// ValidationResult holds the outcome of ValidateKeyDefsWithOptions, separating +// product-vs-catalog mismatches from catalog-internal env var conflicts. +type ValidationResult struct { + // Mismatches contains divergences between product KeyDefs and catalog SharedKeys. + Mismatches []Mismatch + + // EnvVarConflicts contains catalog-internal duplicate env var mappings + // where two SharedKeys map the same environment variable. + EnvVarConflicts []EnvVarConflict +} + // ValidateOption configures the behavior of ValidateKeyDefs. type ValidateOption func(*validateConfig) @@ -108,14 +119,19 @@ func (vc *validateConfig) shouldSkip(catalogKey, field string) bool { // ValidateKeyDefsWithOptions which accepts ValidateOption values. // ValidateKeyDefs retains its original signature for backward compatibility. 
func ValidateKeyDefs(productDefs []domain.KeyDef, catalogKeys ...[]SharedKey) []Mismatch { - return ValidateKeyDefsWithOptions(productDefs, catalogKeys, nil) + result := ValidateKeyDefsWithOptions(productDefs, catalogKeys, nil) + return result.Mismatches } // ValidateKeyDefsWithOptions is the full-featured variant of ValidateKeyDefs // that accepts filtering options. Use WithIgnoreFields and WithKnownDeviation // to suppress expected mismatches in product catalog tests. -func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]SharedKey, opts []ValidateOption) []Mismatch { - keyIndex, envIndex := buildCatalogIndexes(catalogKeys...) +// +// Returns a ValidationResult separating product-vs-catalog mismatches from +// catalog-internal env var conflicts. Use result.Mismatches for product drift +// and result.EnvVarConflicts for catalog-internal duplicate env var mappings. +func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]SharedKey, opts []ValidateOption) ValidationResult { + keyIndex, envIndex, envConflicts := buildCatalogIndexes(catalogKeys...) 
vc := newValidateConfig(opts) var mismatches []Mismatch @@ -144,23 +160,77 @@ func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]Sha return mismatches[i].Field < mismatches[j].Field }) - return mismatches + sort.Slice(envConflicts, func(i, j int) bool { + if envConflicts[i].EnvVar != envConflicts[j].EnvVar { + return envConflicts[i].EnvVar < envConflicts[j].EnvVar + } + + return envConflicts[i].ExistingKey < envConflicts[j].ExistingKey + }) + + return ValidationResult{ + Mismatches: mismatches, + EnvVarConflicts: envConflicts, + } } -func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[string]SharedKey) { +func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[string]SharedKey, []EnvVarConflict) { keyIndex := make(map[string]SharedKey) envIndex := make(map[string]SharedKey) + ambiguousEnvVars := make(map[string]bool) + + var conflicts []EnvVarConflict for _, slice := range catalogKeys { for _, sk := range slice { keyIndex[sk.Key] = sk for _, envVar := range allowedEnvVars(sk) { + if ambiguousEnvVars[envVar] { + // Already known ambiguous from a prior conflict; record + // this additional claimant but do not re-add to envIndex. + conflicts = append(conflicts, EnvVarConflict{ + EnvVar: envVar, + ExistingKey: "(ambiguous)", + ConflictKey: sk.Key, + }) + + continue + } + + if existing, duplicate := envIndex[envVar]; duplicate { + conflicts = append(conflicts, EnvVarConflict{ + EnvVar: envVar, + ExistingKey: existing.Key, + ConflictKey: sk.Key, + }) + + // Mark as ambiguous and remove from envIndex so + // resolveSharedKey will not match this env var. + ambiguousEnvVars[envVar] = true + delete(envIndex, envVar) + + continue + } + envIndex[envVar] = sk } } } - return keyIndex, envIndex + return keyIndex, envIndex, conflicts +} + +// EnvVarConflict describes a duplicate environment variable mapping detected +// during catalog index construction. Two SharedKey entries map the same env var. 
+type EnvVarConflict struct { + EnvVar string // the duplicated environment variable + ExistingKey string // the SharedKey.Key that was registered first + ConflictKey string // the SharedKey.Key that attempted to overwrite +} + +// String returns a human-readable description of the conflict. +func (c EnvVarConflict) String() string { + return fmt.Sprintf("duplicate env var %q: mapped by both %q and %q", c.EnvVar, c.ExistingKey, c.ConflictKey) } func resolveSharedKey(pd domain.KeyDef, keyIndex map[string]SharedKey, envIndex map[string]SharedKey) (SharedKey, bool, bool) { diff --git a/commons/systemplane/catalog/validate_options_test.go b/commons/systemplane/catalog/validate_options_test.go index 561ea412..927766e5 100644 --- a/commons/systemplane/catalog/validate_options_test.go +++ b/commons/systemplane/catalog/validate_options_test.go @@ -30,9 +30,9 @@ func TestValidateKeyDefsWithOptions_WithIgnoreFields(t *testing.T) { require.NotEmpty(t, envVarMismatches, "expected EnvVar mismatch without options") // With WithIgnoreFields("EnvVar"): should suppress the EnvVar mismatch. - filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, []ValidateOption{WithIgnoreFields("EnvVar")}) - envVarMismatches = filterByField(filtered, "EnvVar") + envVarMismatches = filterByField(result.Mismatches, "EnvVar") assert.Empty(t, envVarMismatches, "EnvVar mismatches should be suppressed by WithIgnoreFields") } @@ -54,9 +54,9 @@ func TestValidateKeyDefsWithOptions_WithKnownDeviation(t *testing.T) { require.NotEmpty(t, componentMismatches, "expected Component mismatch without options") // With WithKnownDeviation: should suppress only that key+field. 
- filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, []ValidateOption{WithKnownDeviation("app.log_level", "Component")}) - componentMismatches = filterByField(filtered, "Component") + componentMismatches = filterByField(result.Mismatches, "Component") assert.Empty(t, componentMismatches, "Component mismatch should be suppressed by WithKnownDeviation") } @@ -76,12 +76,12 @@ func TestValidateKeyDefsWithOptions_KnownDeviationDoesNotSuppressOtherKeys(t *te }, } - filtered := ValidateKeyDefsWithOptions(productDefs, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys(), CORSKeys()}, []ValidateOption{WithKnownDeviation("app.log_level", "Component")}) // app.log_level Component should be suppressed. - for _, mm := range filtered { + for _, mm := range result.Mismatches { if mm.CatalogKey == "app.log_level" && mm.Field == "Component" { t.Error("app.log_level Component deviation should have been suppressed") } @@ -89,7 +89,7 @@ func TestValidateKeyDefsWithOptions_KnownDeviationDoesNotSuppressOtherKeys(t *te // cors.allowed_origins Component should NOT be suppressed. corsComponentFound := false - for _, mm := range filtered { + for _, mm := range result.Mismatches { if mm.CatalogKey == "cors.allowed_origins" && mm.Field == "Component" { corsComponentFound = true } @@ -111,7 +111,7 @@ func TestValidateKeyDefsWithOptions_NilOptions(t *testing.T) { // Nil options should work the same as no options. 
result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, nil) - assert.Empty(t, result, "matching key with nil options should produce no mismatches") + assert.Empty(t, result.Mismatches, "matching key with nil options should produce no mismatches") } func TestValidateKeyDefsWithOptions_CombinedOptions(t *testing.T) { @@ -126,13 +126,13 @@ func TestValidateKeyDefsWithOptions_CombinedOptions(t *testing.T) { }, } - filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, []ValidateOption{ WithIgnoreFields("EnvVar"), WithKnownDeviation("app.log_level", "Component"), }) - assert.Empty(t, filtered, "all mismatches should be suppressed by combined options") + assert.Empty(t, result.Mismatches, "all mismatches should be suppressed by combined options") } func filterByField(mismatches []Mismatch, field string) []Mismatch { diff --git a/commons/systemplane/domain/setting_helpers.go b/commons/systemplane/domain/setting_helpers.go index 62c10573..43f5730b 100644 --- a/commons/systemplane/domain/setting_helpers.go +++ b/commons/systemplane/domain/setting_helpers.go @@ -9,13 +9,13 @@ func SnapSettingString(snap *Snapshot, tenantID, key string, fallback string) st return fallback } - if raw, ok := snap.GetTenantSetting(tenantID, key); ok { + if raw, ok := snap.GetTenantSetting(tenantID, key); ok && raw.Value != nil { if value, converted := tryCoerceString(raw.Value); converted { return value } } - if raw, ok := snap.GetGlobalSetting(key); ok { + if raw, ok := snap.GetGlobalSetting(key); ok && raw.Value != nil { if value, converted := tryCoerceString(raw.Value); converted { return value } @@ -31,13 +31,13 @@ func SnapSettingInt(snap *Snapshot, tenantID, key string, fallback int) int { return fallback } - if raw, ok := snap.GetTenantSetting(tenantID, key); ok { + if raw, ok := snap.GetTenantSetting(tenantID, key); ok && raw.Value != nil { if value, 
converted := tryCoerceInt(raw.Value); converted { return value } } - if raw, ok := snap.GetGlobalSetting(key); ok { + if raw, ok := snap.GetGlobalSetting(key); ok && raw.Value != nil { if value, converted := tryCoerceInt(raw.Value); converted { return value } @@ -53,13 +53,13 @@ func SnapSettingBool(snap *Snapshot, tenantID, key string, fallback bool) bool { return fallback } - if raw, ok := snap.GetTenantSetting(tenantID, key); ok { + if raw, ok := snap.GetTenantSetting(tenantID, key); ok && raw.Value != nil { if value, converted := tryCoerceBool(raw.Value); converted { return value } } - if raw, ok := snap.GetGlobalSetting(key); ok { + if raw, ok := snap.GetGlobalSetting(key); ok && raw.Value != nil { if value, converted := tryCoerceBool(raw.Value); converted { return value } diff --git a/commons/systemplane/service/supervisor.go b/commons/systemplane/service/supervisor.go index 5c07434e..cf5fa844 100644 --- a/commons/systemplane/service/supervisor.go +++ b/commons/systemplane/service/supervisor.go @@ -212,14 +212,16 @@ func (supervisor *defaultSupervisor) Reload(ctx context.Context, reason string, supervisor.mu.Lock() defer supervisor.mu.Unlock() + currentState := supervisor.state.Load() + var prevSnap *domain.Snapshot - if st := supervisor.state.Load(); st != nil { - prevSnap = &st.snapshot + if currentState != nil { + prevSnap = &currentState.snapshot } tenantIDs := mergeUniqueTenantIDs(cachedTenantIDs(prevSnap), extraTenantIDs) - build, err := supervisor.prepareReloadBuild(ctx, tenantIDs) + build, err := supervisor.prepareReloadBuild(ctx, tenantIDs, currentState) if err != nil { libOpentelemetry.HandleSpanError(span, "build runtime bundle", err) return err diff --git a/commons/systemplane/service/supervisor_helpers.go b/commons/systemplane/service/supervisor_helpers.go index 32df7b5f..e1e170eb 100644 --- a/commons/systemplane/service/supervisor_helpers.go +++ b/commons/systemplane/service/supervisor_helpers.go @@ -147,7 +147,7 @@ func (supervisor 
*defaultSupervisor) buildBundle( // AdoptResourcesFrom, which runs AFTER the atomic state swap. At that point // Current() already returns the new candidate, so concurrent readers never // observe the mutation. -func (supervisor *defaultSupervisor) prepareReloadBuild(ctx context.Context, tenantIDs []string) (reloadBuild, error) { +func (supervisor *defaultSupervisor) prepareReloadBuild(ctx context.Context, tenantIDs []string, currentState *supervisorState) (reloadBuild, error) { snap, err := supervisor.builder.BuildFull(ctx, tenantIDs...) if err != nil { return reloadBuild{}, fmt.Errorf("reload: %w: %w", domain.ErrSnapshotBuildFailed, err) @@ -157,9 +157,9 @@ func (supervisor *defaultSupervisor) prepareReloadBuild(ctx context.Context, ten var previousBundle domain.RuntimeBundle - if st := supervisor.state.Load(); st != nil { - prevSnap = &st.snapshot - previousBundle = st.bundle + if currentState != nil { + prevSnap = &currentState.snapshot + previousBundle = currentState.bundle } candidate, strategy, err := supervisor.buildBundle(ctx, snap, previousBundle, prevSnap) diff --git a/commons/webhook/deliverer.go b/commons/webhook/deliverer.go index bb2374df..2b9a9055 100644 --- a/commons/webhook/deliverer.go +++ b/commons/webhook/deliverer.go @@ -26,6 +26,21 @@ import ( "github.com/LerianStudio/lib-commons/v4/commons/runtime" ) +// SignatureVersion controls the HMAC signing format used for X-Webhook-Signature. +type SignatureVersion int + +const ( + // SignatureV0 produces legacy payload-only signatures: "sha256=<hex(HMAC-SHA256(payload, secret))>". + // This is the default for backward compatibility with existing consumers. + SignatureV0 SignatureVersion = iota + + // SignatureV1 produces versioned timestamp-bound signatures: + // "v1,sha256=<hex(HMAC-SHA256("v1:<timestamp>.<payload>", secret))>". + // The timestamp is included in the HMAC input to prevent replay attacks. + // Receivers must verify freshness (e.g., reject timestamps older than 5 minutes). + SignatureV1 +) + // Defaults for Deliverer configuration. 
const ( defaultMaxConcurrency = 20 @@ -51,6 +66,7 @@ type Deliverer struct { decryptor SecretDecryptor maxConc int maxRetries int + sigVersion SignatureVersion } // Option configures a Deliverer at construction time. @@ -128,6 +144,21 @@ func WithSecretDecryptor(fn SecretDecryptor) Option { } } +// WithSignatureVersion selects the HMAC signing format for X-Webhook-Signature. +// The default is SignatureV0 (payload-only) for backward compatibility. +// SignatureV1 produces a versioned "v1,sha256=..." signature string that binds +// the event timestamp into the HMAC input, enabling replay protection. +// Receivers can enforce freshness using [VerifySignatureWithFreshness], or +// perform basic signature verification using [VerifySignature]. +// +// Migration path: switch to SignatureV1 only after all consumers have been +// updated to verify the "v1,sha256=..." format. +func WithSignatureVersion(v SignatureVersion) Option { + return func(d *Deliverer) { + d.sigVersion = v + } +} + // defaultHTTPClient creates an http.Client optimized for webhook delivery. // Connection pooling avoids TCP+TLS handshake overhead on repeated deliveries // to the same endpoint — critical at scale where hundreds of webhooks per @@ -202,6 +233,8 @@ func (d *Deliverer) Deliver(ctx context.Context, event *Event) error { return fmt.Errorf("webhook: list endpoints: %w", err) } + // Defensive filter: EndpointLister contract guarantees active-only, + // but guard against faulty implementations. active := filterActive(endpoints) if len(active) == 0 { d.log(ctx, log.LevelDebug, "no active endpoints for event", @@ -239,6 +272,8 @@ func (d *Deliverer) DeliverWithResults(ctx context.Context, event *Event) []Deli }} } + // Defensive filter: EndpointLister contract guarantees active-only, + // but guard against faulty implementations. 
active := filterActive(endpoints) if len(active) == 0 { return nil @@ -335,7 +370,7 @@ func (d *Deliverer) deliverToEndpoint( result := DeliveryResult{EndpointID: ep.ID} // --- SSRF validation + DNS pinning (single lookup, eliminates TOCTOU) --- - pinnedURL, originalHost, ssrfErr := resolveAndValidateIP(ctx, ep.URL) + pinnedURL, originalAuthority, sniHostname, ssrfErr := resolveAndValidateIP(ctx, ep.URL) if ssrfErr != nil { span.RecordError(ssrfErr) span.SetStatus(codes.Error, "SSRF blocked") @@ -379,7 +414,7 @@ func (d *Deliverer) deliverToEndpoint( } } - statusCode, err := d.doHTTP(ctx, pinnedURL, originalHost, event, secret) + statusCode, err := d.doHTTP(ctx, pinnedURL, originalAuthority, sniHostname, event, secret) result.StatusCode = statusCode if err != nil { @@ -408,7 +443,7 @@ func (d *Deliverer) deliverToEndpoint( // Non-retryable client errors — break immediately (except 429 Too Many Requests). if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError && statusCode != http.StatusTooManyRequests { result.Error = fmt.Errorf("webhook: non-retryable status %d", statusCode) - d.recordMetrics(ctx, ep.ID, false, statusCode, attempt+1) + d.recordMetrics(ctx, ep.ID, false, statusCode, result.Attempts) return result } @@ -438,10 +473,16 @@ func (d *Deliverer) deliverToEndpoint( // doHTTP builds and executes a single HTTP request to the (possibly pinned) URL. // Returns the status code and any transport-level error. +// +// originalAuthority is the full host:port authority from the original URL, used +// as the HTTP Host header to preserve explicit non-default ports. +// sniHostname is the bare hostname (port stripped), used for TLS SNI and +// certificate verification when the URL has been rewritten to a pinned IP. 
func (d *Deliverer) doHTTP( ctx context.Context, pinnedURL string, - originalHost string, + originalAuthority string, + sniHostname string, event *Event, secret string, ) (int, error) { @@ -455,21 +496,22 @@ func (d *Deliverer) doHTTP( req.Header.Set("X-Webhook-Timestamp", strconv.FormatInt(event.Timestamp, 10)) // When the URL was rewritten to use the pinned IP, set the Host header - // for virtual hosting and use TLSClientConfig.ServerName so TLS SNI and - // certificate verification use the original hostname (not the IP). - if originalHost != "" { - req.Host = originalHost + // with the original authority (host:port) for virtual hosting, and use + // TLSClientConfig.ServerName with the bare hostname for TLS SNI and + // certificate verification (not the IP). + if originalAuthority != "" { + req.Host = originalAuthority } if secret != "" { - sig := computeHMAC(event.Payload, secret) - req.Header.Set("X-Webhook-Signature", "sha256="+sig) + sig := d.computeSignature(event.Payload, event.Timestamp, secret) + req.Header.Set("X-Webhook-Signature", sig) } client := d.client - if originalHost != "" && strings.HasPrefix(pinnedURL, "https://") { - client = d.httpsClientForPinnedIP(originalHost) + if sniHostname != "" && strings.HasPrefix(pinnedURL, "https://") { + client = d.httpsClientForPinnedIP(sniHostname) } resp, err := client.Do(req) @@ -506,18 +548,25 @@ func (d *Deliverer) resolveSecret(raw string) (string, error) { return plaintext, nil } +// computeSignature dispatches to the appropriate HMAC format based on the +// configured signature version. +func (d *Deliverer) computeSignature(payload []byte, timestamp int64, secret string) string { + switch d.sigVersion { + case SignatureV1: + return computeHMACv1(payload, timestamp, secret) + default: + return "sha256=" + computeHMAC(payload, secret) + } +} + // computeHMAC returns the hex-encoded HMAC-SHA256 of payload using the given secret. +// This is the legacy (v0) format that signs the raw payload only. 
// -// Design note — timestamp not included in signature (by intent): +// Design note — timestamp not included in v0 signature (by intent): // The signature covers the raw payload only, not the X-Webhook-Timestamp value. -// Some industry integrations (e.g., Stripe) sign "timestamp.payload" to bind the -// timestamp to the signature and prevent replay attacks. We intentionally do not -// do this because: (a) changing the input format would silently break all existing -// consumers who already verify "sha256=HMAC(payload)" and (b) replay protection -// at the application layer (e.g., idempotency keys, short timestamp windows) is -// the responsibility of the receiving service. Receivers who need replay protection -// should validate that X-Webhook-Timestamp is within an acceptable window (e.g., -// ±5 minutes) independently of the signature check. +// This format is maintained for backward compatibility. New deployments should +// prefer SignatureV1 (via WithSignatureVersion) which binds the timestamp into +// the HMAC input to enable replay protection. func computeHMAC(payload []byte, secret string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(payload) @@ -525,6 +574,91 @@ func computeHMAC(payload []byte, secret string) string { return hex.EncodeToString(mac.Sum(nil)) } +// computeHMACv1 returns a versioned signature string: +// +// "v1,sha256=<hex(HMAC-SHA256("v1:<timestamp>.<payload>", secret))>" +// +// The version prefix "v1:" followed by the decimal timestamp and a dot separator +// are prepended to the payload before computing the HMAC, binding the timestamp +// to the signature. Receivers must parse the "v1," prefix, extract the timestamp +// from the X-Webhook-Timestamp header, reconstruct the signing input, and verify +// the HMAC before accepting the webhook. 
+func computeHMACv1(payload []byte, timestamp int64, secret string) string { + ts := strconv.FormatInt(timestamp, 10) + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte("v1:")) + mac.Write([]byte(ts)) + mac.Write([]byte(".")) + mac.Write(payload) + + return "v1,sha256=" + hex.EncodeToString(mac.Sum(nil)) +} + +// VerifySignature verifies a webhook signature by auto-detecting the version +// from the signature string format. Returns nil on valid signature, or an error +// describing the mismatch. +// +// Supported formats: +// - "sha256=<hex>" — v0 (payload-only) +// - "v1,sha256=<hex>" — v1 (timestamp-bound) +// +// For v1 signatures, the timestamp parameter is required to reconstruct the +// signing input. For v0 signatures, the timestamp is ignored. +func VerifySignature(payload []byte, timestamp int64, secret, signature string) error { + switch { + case strings.HasPrefix(signature, "v1,"): + expected := computeHMACv1(payload, timestamp, secret) + if !hmac.Equal([]byte(signature), []byte(expected)) { + return errors.New("webhook: v1 signature mismatch") + } + + return nil + + case strings.HasPrefix(signature, "sha256="): + expected := "sha256=" + computeHMAC(payload, secret) + if !hmac.Equal([]byte(signature), []byte(expected)) { + return errors.New("webhook: v0 signature mismatch") + } + + return nil + + default: + return errors.New("webhook: unrecognized signature format") + } +} + +// VerifySignatureWithFreshness verifies a v1 webhook signature and additionally +// checks that the timestamp is within the given tolerance window from now. +// This provides replay protection: even if an attacker captures a valid +// payload+signature pair, it becomes invalid after the tolerance window expires. +// +// For v0 ("sha256=...") signatures, freshness cannot be enforced because the +// timestamp is not covered by the HMAC. 
Callers receiving v0 signatures should +// use VerifySignature and implement replay protection independently (e.g., +// idempotency keys or event-ID tracking). +func VerifySignatureWithFreshness(payload []byte, timestamp int64, secret, signature string, tolerance time.Duration) error { + if err := VerifySignature(payload, timestamp, secret, signature); err != nil { + return err + } + + // Freshness check only applies to v1 where the timestamp is signed. + if strings.HasPrefix(signature, "v1,") { + eventTime := time.Unix(timestamp, 0) + delta := time.Since(eventTime) + + if delta < 0 { + delta = -delta + } + + if delta > tolerance { + return fmt.Errorf("webhook: timestamp outside tolerance window (%s > %s)", delta.Truncate(time.Second), tolerance) + } + } + + return nil +} + // filterActive returns only endpoints where Active is true. func filterActive(endpoints []Endpoint) []Endpoint { active := make([]Endpoint, 0, len(endpoints)) @@ -605,17 +739,21 @@ func (d *Deliverer) httpsClientForPinnedIP(originalHost string) *http.Client { return &clone } -// sanitizeURL strips query parameters from a URL before logging to prevent -// credential leakage. Webhook URLs may carry tokens in query params -// (e.g., ?token=..., ?api_key=...) that must not appear in log output. -// On parse failure the raw string is returned unchanged so no log line is lost. +// sanitizeURL strips query parameters and userinfo from a URL before logging +// to prevent credential leakage. Webhook URLs may carry tokens in query params +// (e.g., ?token=..., ?api_key=...) or credentials in the userinfo component +// (e.g., https://user:pass@host/...) that must not appear in log output. +// On parse failure a safe placeholder is returned instead of the raw input +// to avoid leaking credentials embedded in malformed URLs. 
func sanitizeURL(rawURL string) string { u, err := url.Parse(rawURL) if err != nil { - return rawURL + return "[invalid-url]" } u.RawQuery = "" + u.User = nil + u.Fragment = "" return u.String() } diff --git a/commons/webhook/deliverer_test.go b/commons/webhook/deliverer_test.go index ef4fe7dc..bcca3729 100644 --- a/commons/webhook/deliverer_test.go +++ b/commons/webhook/deliverer_test.go @@ -14,6 +14,7 @@ import ( "net/http" "net/http/httptest" "strconv" + "strings" "sync" "sync/atomic" "testing" @@ -900,3 +901,237 @@ func TestResolveSecret(t *testing.T) { }) } } + +// --------------------------------------------------------------------------- +// sanitizeURL — credential redaction +// --------------------------------------------------------------------------- + +func TestSanitizeURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + want string + }{ + { + name: "strips query params", + raw: "https://example.com/webhook?token=secret&api_key=abc", + want: "https://example.com/webhook", + }, + { + name: "strips userinfo", + raw: "https://user:pass@example.com/webhook", + want: "https://example.com/webhook", + }, + { + name: "strips both userinfo and query", + raw: "https://admin:hunter2@example.com/hook?key=val", + want: "https://example.com/hook", + }, + { + name: "no credentials passes through", + raw: "https://example.com/webhook", + want: "https://example.com/webhook", + }, + { + name: "strips fragment", + raw: "https://example.com/webhook#access_token=secret", + want: "https://example.com/webhook", + }, + { + name: "strips fragment with query and userinfo", + raw: "https://user:pass@example.com/hook?key=val#frag", + want: "https://example.com/hook", + }, + { + name: "invalid URL returns placeholder", + raw: "://bad\x7f", + want: "[invalid-url]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := sanitizeURL(tt.raw) + assert.Equal(t, tt.want, got) + }) + } +} + +// 
--------------------------------------------------------------------------- +// Versioned signatures — v1 HMAC with timestamp binding +// --------------------------------------------------------------------------- + +func TestComputeHMACv1(t *testing.T) { + t.Parallel() + + payload := []byte(`{"id":"123"}`) + secret := "test-secret" + timestamp := int64(1700000000) + + // Independently computed expected value. + // Wire format: HMAC-SHA256("v1:1700000000.{\"id\":\"123\"}", "test-secret") + // prefixed with "v1,sha256=". + const expected = "v1,sha256=49d0851d6baa34655d12dfba153c748db08b84830d4cb01780d833956c59844c" + + sig := computeHMACv1(payload, timestamp, secret) + assert.Equal(t, expected, sig, + "v1 signature must match independently computed HMAC-SHA256 of 'v1:.'") + + // Verify it is deterministic. + sig2 := computeHMACv1(payload, timestamp, secret) + assert.Equal(t, sig, sig2, "same input must produce same signature") + + // Different timestamp must produce different signature. + sig3 := computeHMACv1(payload, timestamp+1, secret) + assert.NotEqual(t, sig, sig3, "different timestamp must produce different signature") +} + +func TestComputeSignature_V0Default(t *testing.T) { + t.Parallel() + + d := &Deliverer{sigVersion: SignatureV0} + payload := []byte(`{"id":"123"}`) + timestamp := int64(1700000000) + secret := "test-secret" + + sig := d.computeSignature(payload, timestamp, secret) + assert.True(t, strings.HasPrefix(sig, "sha256="), + "v0 signature must start with 'sha256='") + assert.False(t, strings.HasPrefix(sig, "v1,"), + "v0 signature must NOT have v1 prefix") +} + +func TestComputeSignature_V1(t *testing.T) { + t.Parallel() + + d := &Deliverer{sigVersion: SignatureV1} + payload := []byte(`{"id":"123"}`) + timestamp := int64(1700000000) + secret := "test-secret" + + sig := d.computeSignature(payload, timestamp, secret) + assert.True(t, strings.HasPrefix(sig, "v1,sha256="), + "v1 signature must start with 'v1,sha256='") +} + +func TestVerifySignature(t 
*testing.T) { + t.Parallel() + + payload := []byte(`{"id":"123"}`) + secret := "test-secret" + timestamp := int64(1700000000) + + tests := []struct { + name string + sig string + wantErr string + }{ + { + name: "valid v0 signature", + sig: "sha256=" + computeHMAC(payload, secret), + }, + { + name: "valid v1 signature", + sig: computeHMACv1(payload, timestamp, secret), + }, + { + name: "invalid v0 signature", + sig: "sha256=deadbeef", + wantErr: "v0 signature mismatch", + }, + { + name: "invalid v1 signature", + sig: "v1,sha256=deadbeef", + wantErr: "v1 signature mismatch", + }, + { + name: "unrecognized format", + sig: "v99,md5=abc", + wantErr: "unrecognized signature format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := VerifySignature(payload, timestamp, secret, tt.sig) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestVerifySignatureWithFreshness(t *testing.T) { + t.Parallel() + + payload := []byte(`{"id":"123"}`) + secret := "test-secret" + now := time.Now().Unix() + + // Fresh v1 signature (timestamp = now). + freshSig := computeHMACv1(payload, now, secret) + + err := VerifySignatureWithFreshness(payload, now, secret, freshSig, 5*time.Minute) + assert.NoError(t, err, "fresh v1 signature within tolerance must pass") + + // Stale v1 signature (timestamp = 1 hour ago). + staleTS := now - 3600 + staleSig := computeHMACv1(payload, staleTS, secret) + + err = VerifySignatureWithFreshness(payload, staleTS, secret, staleSig, 5*time.Minute) + require.Error(t, err) + assert.Contains(t, err.Error(), "outside tolerance window") + + // V0 signature — freshness check is skipped (timestamp not signed). 
+ v0Sig := "sha256=" + computeHMAC(payload, secret) + + err = VerifySignatureWithFreshness(payload, staleTS, secret, v0Sig, 5*time.Minute) + assert.NoError(t, err, "v0 signature skips freshness check") +} + +func TestDeliver_V1Signature_EndToEnd(t *testing.T) { + t.Parallel() + + secret := "v1-secret" + event := newTestEvent() + + var gotSig string + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSig = r.Header.Get("X-Webhook-Signature") + w.WriteHeader(http.StatusOK) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-v1", URL: pubURL, Secret: secret, Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithSignatureVersion(SignatureV1), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), event) + require.Len(t, results, 1) + require.True(t, results[0].Success) + + // The signature should be v1 format. + assert.True(t, strings.HasPrefix(gotSig, "v1,sha256="), + "v1 deliverer must produce v1 signature format, got: %s", gotSig) + + // Verify the signature is correct. 
+	err := VerifySignature(event.Payload, event.Timestamp, secret, gotSig)
+	assert.NoError(t, err, "v1 signature from deliverer must verify correctly")
+}
diff --git a/commons/webhook/doc.go b/commons/webhook/doc.go
index ded26825..06de7ba8 100644
--- a/commons/webhook/doc.go
+++ b/commons/webhook/doc.go
@@ -6,8 +6,8 @@
 // - Two-layer SSRF protection: pre-resolution IP validation + DNS-pinned delivery
 // - DNS rebinding prevention via resolved IP pinning
 // - Blocks private, loopback, and link-local IP ranges
-// - HMAC-SHA256 signature in X-Webhook-Signature header (sha256=HEX)
-// - URL query parameters stripped from log output to prevent credential leakage
+// - HMAC-SHA256 signature in X-Webhook-Signature header (versioned format)
+// - URL query parameters and userinfo stripped from log output to prevent credential leakage
 //
 // # Delivery model
 //
@@ -15,16 +15,46 @@
 // - Exponential backoff with jitter (1s, 2s, 4s, ...)
 // - Per-endpoint retry with configurable max attempts (default: 3)
 //
-// # HMAC signature scope
+// # HMAC signature versions
 //
-// The X-Webhook-Signature header carries HMAC-SHA256 computed over the raw
-// payload bytes only. X-Webhook-Timestamp is sent as a separate informational
-// header but is NOT covered by the HMAC — an attacker who captures a valid
-// payload+signature pair can replay it with a fresh timestamp.
+// The X-Webhook-Signature header supports two formats, selectable via
+// WithSignatureVersion at Deliverer construction time:
 //
-// Receivers who need replay protection must implement it independently, for
-// example by tracking event IDs or embedding a nonce in the payload itself.
-// Timestamp-window checks alone are insufficient because the timestamp is
-// unsigned. Including the timestamp in the HMAC would be a breaking change
-// for existing consumers.
+// ## v0 (default — backward compatible)
+//
+// Format: "sha256=<hex(HMAC-SHA256(payload, secret))>"
+//
+// The HMAC covers the raw payload bytes only.
X-Webhook-Timestamp is sent as
+// a separate informational header but is NOT covered by the HMAC. An attacker
+// who captures a valid payload+signature pair can replay it. Receivers must
+// implement replay protection independently (e.g., event-ID tracking,
+// idempotency keys).
+//
+// ## v1 (recommended for new deployments)
+//
+// Format: "v1,sha256=<hex(HMAC-SHA256("v1:" + timestamp + "." + payload, secret))>"
+//
+// The HMAC input includes a version prefix, the decimal Unix-epoch timestamp
+// from X-Webhook-Timestamp, and the raw payload — binding the timestamp to
+// the signature. Receivers must:
+// 1. Parse the "v1," prefix from X-Webhook-Signature.
+// 2. Read X-Webhook-Timestamp and reconstruct the signing input as
+// "v1:<timestamp>.<payload>".
+// 3. Compute HMAC-SHA256 with the shared secret and compare (constant-time).
+// 4. Reject timestamps outside an acceptable clock-skew window (e.g., +/- 5 min)
+// or track event IDs / nonces to prevent replay.
+//
+// Use VerifySignature or VerifySignatureWithFreshness for receiver-side
+// verification — both auto-detect the version from the signature string.
+//
+// # Migration from v0 to v1
+//
+// Because this is a library used by multiple services, the default remains v0
+// to avoid breaking existing consumers. To migrate:
+// 1. Update all webhook receivers to accept both v0 and v1 formats (use
+// VerifySignature which auto-detects the version).
+// 2. Once all receivers are updated, switch senders to v1 by constructing
+// the Deliverer with WithSignatureVersion(SignatureV1).
+// 3. After a transition period, receivers may optionally reject v0 signatures
+// to enforce replay protection.
 package webhook
diff --git a/commons/webhook/ssrf.go b/commons/webhook/ssrf.go
index 46637131..a5e2eaad 100644
--- a/commons/webhook/ssrf.go
+++ b/commons/webhook/ssrf.go
@@ -2,151 +2,48 @@ package webhook
 
 import (
 	"context"
+	"errors"
 	"fmt"
-	"net"
-	"net/url"
-	"strings"
-)
-
-// cidr4 constructs a static *net.IPNet for an IPv4 CIDR.
Using a helper keeps -// each entry in the table below to a single readable line. All four octets are -// accepted so future CIDR additions don't require changing the signature. -// -//nolint:unparam // d is currently always 0 for these network addresses; kept for generality. -func cidr4(a, b, c, d byte, ones int) *net.IPNet { - return &net.IPNet{ - IP: net.IPv4(a, b, c, d).To4(), - Mask: net.CIDRMask(ones, 32), - } -} -// cgnatBlock is the CGNAT (Carrier-Grade NAT) range defined by RFC 6598. -// Cloud providers frequently use this range for internal routing, so it must -// be blocked to prevent SSRF via addresses like 100.64.0.1. -// -//nolint:gochecknoglobals // package-level CIDR block is intentional for SSRF protection -var cgnatBlock = cidr4(100, 64, 0, 0, 10) - -// additionalBlockedRanges holds CIDR blocks that are not covered by the -// standard net.IP predicates (IsPrivate, IsLoopback, etc.) but must be -// blocked to prevent SSRF attacks. -// -// All entries are compile-time-constructed net.IPNet literals — no runtime -// string parsing, no init() required, and typos surface as test failures -// rather than startup panics. 
-// -//nolint:gochecknoglobals // package-level slice is intentional for SSRF protection -var additionalBlockedRanges = []*net.IPNet{ - cidr4(0, 0, 0, 0, 8), // 0.0.0.0/8 — "this network" (RFC 1122 §3.2.1.3) - cidr4(192, 0, 0, 0, 24), // 192.0.0.0/24 — IETF protocol assignments (RFC 6890) - cidr4(192, 0, 2, 0, 24), // 192.0.2.0/24 — TEST-NET-1 documentation (RFC 5737) - cidr4(198, 18, 0, 0, 15), // 198.18.0.0/15 — benchmarking (RFC 2544) - cidr4(198, 51, 100, 0, 24), // 198.51.100.0/24 — TEST-NET-2 documentation (RFC 5737) - cidr4(203, 0, 113, 0, 24), // 203.0.113.0/24 — TEST-NET-3 documentation (RFC 5737) - cidr4(240, 0, 0, 0, 4), // 240.0.0.0/4 — reserved/future use (RFC 1112) -} + libSSRF "github.com/LerianStudio/lib-commons/v4/commons/security/ssrf" +) -// resolveAndValidateIP performs a single DNS lookup for the hostname in rawURL, -// validates every resolved IP against the SSRF blocklist, and returns a new URL -// with the hostname replaced by the first resolved IP (DNS pinning). -// -// Combining validation and pinning into one lookup eliminates the TOCTOU window -// that exists when validateResolvedIP and pinResolvedIP are called sequentially: -// a DNS rebinding attack could change the record between those two calls, causing -// the pinned IP to differ from the validated one. +// resolveAndValidateIP delegates to the canonical [libSSRF.ResolveAndValidate] +// function and maps its result back to the 4-return-value signature expected by +// [deliverToEndpoint]. This keeps the webhook package's internal call-sites +// stable while removing the duplicated SSRF implementation. // // On success it returns: -// - pinnedURL — original URL with the hostname replaced by the first resolved IP. -// - originalHost — the original hostname, for use as the HTTP Host header (TLS SNI). +// - pinnedURL — original URL with the hostname replaced by the first safe resolved IP. +// - originalAuthority — the original host:port authority, suitable for the HTTP Host header. 
+// - sniHostname — the bare hostname (port stripped), suitable for TLS SNI / certificate verification. // -// DNS lookup failures are fail-closed: if the hostname cannot be resolved, the -// URL is rejected. When no resolved IP can be parsed from the DNS response the -// URL is considered unresolvable and an error is returned. -func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL string, originalHost string, err error) { - u, err := url.Parse(rawURL) - if err != nil { - return "", "", fmt.Errorf("%w: %w", ErrInvalidURL, err) - } - - scheme := strings.ToLower(u.Scheme) - if scheme != "http" && scheme != "https" { - return "", "", fmt.Errorf("%w: scheme %q not allowed", ErrSSRFBlocked, scheme) - } - - host := u.Hostname() - if host == "" { - return "", "", fmt.Errorf("%w: empty hostname", ErrInvalidURL) - } - - ips, dnsErr := net.DefaultResolver.LookupHost(ctx, host) - if dnsErr != nil { - return "", "", fmt.Errorf("%w: DNS lookup failed for %s: %w", ErrSSRFBlocked, host, dnsErr) +// Error mapping: +// - [libSSRF.ErrBlocked] → [ErrSSRFBlocked] +// - [libSSRF.ErrDNSFailed] → [ErrSSRFBlocked] +// - [libSSRF.ErrInvalidURL] → [ErrInvalidURL] +func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL, originalAuthority, sniHostname string, err error) { + result, resolveErr := libSSRF.ResolveAndValidate(ctx, rawURL) + if resolveErr != nil { + return "", "", "", mapSSRFError(resolveErr) } - if len(ips) == 0 { - return "", "", fmt.Errorf("%w: DNS returned no addresses for %s", ErrSSRFBlocked, host) - } - - var firstValidIP string - - for _, ipStr := range ips { - ip := net.ParseIP(ipStr) - if ip == nil { - continue - } - - if isPrivateIP(ip) { - return "", "", fmt.Errorf("%w: resolved IP %s is private/loopback", ErrSSRFBlocked, ipStr) - } - - if firstValidIP == "" { - firstValidIP = ipStr - } - } - - if firstValidIP == "" { - return "", "", fmt.Errorf("%w: no valid IPs resolved for %s", ErrInvalidURL, host) - } - - // Pin to first valid 
resolved IP to prevent DNS rebinding across retries. - port := u.Port() + return result.PinnedURL, result.Authority, result.SNIHostname, nil +} +// mapSSRFError translates sentinel errors from the canonical ssrf package into +// the webhook package's error types so that existing callers (and tests) that +// check [errors.Is] against [ErrSSRFBlocked] / [ErrInvalidURL] continue to +// work without modification. +func mapSSRFError(err error) error { switch { - case port != "": - u.Host = net.JoinHostPort(firstValidIP, port) - case strings.Contains(firstValidIP, ":"): - // Bare IPv6 literal must be bracket-wrapped for url.URL.Host. - u.Host = "[" + firstValidIP + "]" + case errors.Is(err, libSSRF.ErrInvalidURL): + return fmt.Errorf("%w: %w", ErrInvalidURL, err) + case errors.Is(err, libSSRF.ErrBlocked): + return fmt.Errorf("%w: %w", ErrSSRFBlocked, err) + case errors.Is(err, libSSRF.ErrDNSFailed): + return fmt.Errorf("%w: %w", ErrSSRFBlocked, err) default: - u.Host = firstValidIP + return fmt.Errorf("%w: %w", ErrSSRFBlocked, err) } - - return u.String(), host, nil -} - -// isPrivateIP reports whether ip is in a private, loopback, link-local, -// unspecified, CGNAT, multicast, or other reserved range that must not be -// contacted by webhook delivery (SSRF protection). -// -// In addition to the ranges covered by the standard net.IP predicates, this -// function checks the additionalBlockedRanges slice which covers RFC-defined -// special-purpose blocks not included in Go's net package. 
-func isPrivateIP(ip net.IP) bool { - if ip.IsLoopback() || - ip.IsPrivate() || - ip.IsLinkLocalUnicast() || - ip.IsLinkLocalMulticast() || - ip.IsMulticast() || - ip.IsUnspecified() || - cgnatBlock.Contains(ip) { - return true - } - - for _, block := range additionalBlockedRanges { - if block.Contains(ip) { - return true - } - } - - return false } diff --git a/commons/webhook/ssrf_test.go b/commons/webhook/ssrf_test.go index 456cb2dd..de45f395 100644 --- a/commons/webhook/ssrf_test.go +++ b/commons/webhook/ssrf_test.go @@ -5,106 +5,42 @@ package webhook import ( "context" "errors" - "net" + "fmt" "testing" + libSSRF "github.com/LerianStudio/lib-commons/v4/commons/security/ssrf" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- -// isPrivateIP — pure function, no DNS, safe to unit-test exhaustively. +// resolveAndValidateIP — URL validation (no real DNS unless noted) // --------------------------------------------------------------------------- -func TestIsPrivateIP(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - ip string - private bool - }{ - // Loopback - {name: "IPv4 loopback", ip: "127.0.0.1", private: true}, - {name: "IPv6 loopback", ip: "::1", private: true}, - - // RFC 1918 private ranges - {name: "10.0.0.0/8", ip: "10.0.0.1", private: true}, - {name: "10.255.255.255", ip: "10.255.255.255", private: true}, - {name: "172.16.0.0/12", ip: "172.16.0.1", private: true}, - {name: "172.31.255.255", ip: "172.31.255.255", private: true}, - {name: "192.168.0.0/16", ip: "192.168.0.1", private: true}, - {name: "192.168.255.255", ip: "192.168.255.255", private: true}, - - // Link-local unicast - {name: "IPv4 link-local", ip: "169.254.1.1", private: true}, - {name: "IPv6 link-local unicast", ip: "fe80::1", private: true}, - - // Link-local multicast - {name: "IPv4 link-local multicast", ip: "224.0.0.1", private: true}, - - // Unspecified 
addresses (0.0.0.0 / ::) - {name: "IPv4 unspecified", ip: "0.0.0.0", private: true}, - {name: "IPv6 unspecified", ip: "::", private: true}, - - // CGNAT range (RFC 6598): 100.64.0.0/10 - {name: "CGNAT low end", ip: "100.64.0.1", private: true}, - {name: "CGNAT mid range", ip: "100.100.100.100", private: true}, - {name: "CGNAT high end", ip: "100.127.255.254", private: true}, - // Just below CGNAT range — should NOT be private - {name: "100.63.255.255 not CGNAT", ip: "100.63.255.255", private: false}, - // Just above CGNAT range — should NOT be private - {name: "100.128.0.1 not CGNAT", ip: "100.128.0.1", private: false}, - - // Public IPs — should NOT be private - {name: "Google DNS", ip: "8.8.8.8", private: false}, - {name: "Cloudflare", ip: "1.1.1.1", private: false}, - {name: "Public IPv6", ip: "2001:4860:4860::8888", private: false}, - - // Edge: 172.15.x is NOT private (just below 172.16) - {name: "172.15.255.255 not private", ip: "172.15.255.255", private: false}, - // Edge: 172.32.x is NOT private (just above 172.31) - {name: "172.32.0.1 not private", ip: "172.32.0.1", private: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ip := net.ParseIP(tt.ip) - assert.NotNil(t, ip, "failed to parse IP: %s", tt.ip) - assert.Equal(t, tt.private, isPrivateIP(ip), - "isPrivateIP(%s) = %v, want %v", tt.ip, isPrivateIP(ip), tt.private) - }) - } -} - -// --------------------------------------------------------------------------- -// resolveAndValidateIP — URL parsing edge cases only. We do NOT hit real DNS. -// --------------------------------------------------------------------------- - -func TestResolveAndValidateIP_InvalidURLMalformed(t *testing.T) { +// TestResolveAndValidateIP_InvalidURL checks that a completely malformed URL +// is rejected with ErrInvalidURL. 
+func TestResolveAndValidateIP_InvalidURL(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), "://missing-scheme") - assert.Error(t, err) + _, _, _, err := resolveAndValidateIP(context.Background(), "://no-scheme") + require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) } -func TestResolveAndValidateIP_EmptyHostnameHTTP(t *testing.T) { +// TestResolveAndValidateIP_EmptyHostname checks that an empty hostname is +// rejected with ErrInvalidURL and a descriptive message. +func TestResolveAndValidateIP_EmptyHostname(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), "http://") - assert.Error(t, err) + _, _, _, err := resolveAndValidateIP(context.Background(), "http://") + require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) assert.Contains(t, err.Error(), "empty hostname") } -// --------------------------------------------------------------------------- -// resolveAndValidateIP — scheme validation (blocks non-HTTP/HTTPS schemes). -// --------------------------------------------------------------------------- - -func TestResolveAndValidateIP_UnsupportedSchemes(t *testing.T) { +// TestResolveAndValidateIP_BlockedSchemes verifies that non-HTTP/HTTPS schemes +// are rejected before DNS lookup. +func TestResolveAndValidateIP_BlockedSchemes(t *testing.T) { t.Parallel() tests := []struct { @@ -122,62 +58,7 @@ func TestResolveAndValidateIP_UnsupportedSchemes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), tt.url) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrSSRFBlocked) - }) - } -} - -func TestResolveAndValidateIP_AllowedSchemes(t *testing.T) { - t.Parallel() - - // These schemes should pass the scheme check (they may still fail DNS - // resolution, but they should NOT fail with ErrSSRFBlocked for scheme). 
- tests := []struct { - name string - url string - }{ - {name: "http scheme", url: "http://example.com"}, - {name: "https scheme", url: "https://example.com"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - _, _, err := resolveAndValidateIP(context.Background(), tt.url) - // The error, if any, should NOT be ErrSSRFBlocked for scheme reasons. - if err != nil { - assert.False(t, errors.Is(err, ErrSSRFBlocked), "http/https should not be SSRF-blocked: %v", err) - } - }) - } -} - -// --------------------------------------------------------------------------- -// resolveAndValidateIP — URL parsing and scheme blocking (no DNS) -// --------------------------------------------------------------------------- - -// TestResolveAndValidateIP_InvalidScheme checks that non-HTTP/HTTPS schemes are -// rejected before DNS lookup. -func TestResolveAndValidateIP_InvalidScheme(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - url string - }{ - {name: "gopher scheme", url: "gopher://example.com"}, - {name: "file scheme", url: "file:///etc/passwd"}, - {name: "ftp scheme", url: "ftp://example.com/file"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - _, _, err := resolveAndValidateIP(context.Background(), tt.url) + _, _, _, err := resolveAndValidateIP(context.Background(), tt.url) require.Error(t, err) assert.ErrorIs(t, err, ErrSSRFBlocked, "non-HTTP/HTTPS scheme must return ErrSSRFBlocked") @@ -185,90 +66,61 @@ func TestResolveAndValidateIP_InvalidScheme(t *testing.T) { } } -// TestResolveAndValidateIP_EmptyHostname checks that an empty hostname is rejected. 
-func TestResolveAndValidateIP_EmptyHostname(t *testing.T) { - t.Parallel() - - _, _, err := resolveAndValidateIP(context.Background(), "http://") - require.Error(t, err) - assert.ErrorIs(t, err, ErrInvalidURL) - assert.Contains(t, err.Error(), "empty hostname") -} - -// TestResolveAndValidateIP_InvalidURL checks that a completely malformed URL -// is rejected with ErrInvalidURL. -func TestResolveAndValidateIP_InvalidURL(t *testing.T) { - t.Parallel() - - _, _, err := resolveAndValidateIP(context.Background(), "://no-scheme") - require.Error(t, err) - assert.ErrorIs(t, err, ErrInvalidURL) -} - -// TestResolveAndValidateIP_PrivateIP confirms that hostnames that resolve to a -// loopback/private address are blocked. "localhost" resolves to 127.0.0.1 -// (loopback), which must trigger ErrSSRFBlocked. -// -// This test does perform a real DNS lookup for "localhost". On all POSIX -// systems "localhost" is defined in /etc/hosts as 127.0.0.1, so the lookup -// is local and requires no network access. Environments that strip /etc/hosts -// may see the DNS fall-back path (original URL returned, no error), in which -// case the test is skipped gracefully. -func TestResolveAndValidateIP_PrivateIP(t *testing.T) { +// TestResolveAndValidateIP_BlockedHostname confirms that hostnames rejected by +// hostname-level SSRF validation are mapped to ErrSSRFBlocked. +func TestResolveAndValidateIP_BlockedHostname(t *testing.T) { t.Parallel() - // Probe the local DNS before asserting — skip if localhost doesn't resolve. 
- addrs, lookupErr := net.LookupHost("localhost") - if lookupErr != nil || len(addrs) == 0 { - t.Skip("localhost DNS lookup failed or returned no results — skipping SSRF private-IP test") - } - - _, _, err := resolveAndValidateIP(context.Background(), "http://localhost") + _, _, _, err := resolveAndValidateIP(context.Background(), "http://localhost") require.Error(t, err) assert.ErrorIs(t, err, ErrSSRFBlocked, - "localhost (127.0.0.1) must be blocked as a private/loopback address") + "localhost must be blocked by hostname-level SSRF protection") } // --------------------------------------------------------------------------- -// isPrivateIP — additional blocked ranges (RFC-defined special-purpose blocks) +// mapSSRFError — sentinel error translation (fail-closed for unknown errors) // --------------------------------------------------------------------------- -// TestResolveAndValidateIP_AllBlockedRanges tests isPrivateIP directly against -// representative IPs from every additional CIDR block listed in ssrf.go: -// -// - 0.0.0.1 → 0.0.0.0/8 "this network" (RFC 1122 §3.2.1.3) -// - 192.0.0.1 → 192.0.0.0/24 IETF protocol assignments (RFC 6890) -// - 192.0.2.1 → 192.0.2.0/24 TEST-NET-1 documentation (RFC 5737) -// - 198.18.0.1 → 198.18.0.0/15 benchmarking (RFC 2544) -// - 198.51.100.1 → 198.51.100.0/24 TEST-NET-2 documentation (RFC 5737) -// - 203.0.113.1 → 203.0.113.0/24 TEST-NET-3 documentation (RFC 5737) -// - 240.0.0.1 → 240.0.0.0/4 reserved / future use (RFC 1112) -// - 239.255.255.255 → multicast (net.IP.IsMulticast) -func TestResolveAndValidateIP_AllBlockedRanges(t *testing.T) { +// TestMapSSRFError verifies that all four branches of mapSSRFError translate +// canonical SSRF sentinel errors into the webhook package's error types, and +// that unrecognized errors are mapped to ErrSSRFBlocked (fail-closed). 
+func TestMapSSRFError(t *testing.T) { t.Parallel() tests := []struct { - name string - ip string + name string + input error + wantIs error }{ - {name: "0/8 this-network", ip: "0.0.0.1"}, - {name: "192.0.0/24 IETF", ip: "192.0.0.1"}, - {name: "192.0.2/24 TEST-NET-1", ip: "192.0.2.1"}, - {name: "198.18/15 benchmarking", ip: "198.18.0.1"}, - {name: "198.51.100/24 TEST-NET-2", ip: "198.51.100.1"}, - {name: "203.0.113/24 TEST-NET-3", ip: "203.0.113.1"}, - {name: "240/4 reserved", ip: "240.0.0.1"}, - {name: "multicast 239.255.255.255", ip: "239.255.255.255"}, + { + name: "ErrInvalidURL maps to webhook ErrInvalidURL", + input: fmt.Errorf("validation: %w", libSSRF.ErrInvalidURL), + wantIs: ErrInvalidURL, + }, + { + name: "ErrBlocked maps to webhook ErrSSRFBlocked", + input: fmt.Errorf("blocked: %w", libSSRF.ErrBlocked), + wantIs: ErrSSRFBlocked, + }, + { + name: "ErrDNSFailed maps to webhook ErrSSRFBlocked", + input: fmt.Errorf("dns: %w", libSSRF.ErrDNSFailed), + wantIs: ErrSSRFBlocked, + }, + { + name: "unknown error maps to webhook ErrSSRFBlocked (fail-closed)", + input: errors.New("unexpected internal error"), + wantIs: ErrSSRFBlocked, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - ip := net.ParseIP(tt.ip) - require.NotNil(t, ip, "test setup error: %s is not a valid IP", tt.ip) - assert.True(t, isPrivateIP(ip), - "isPrivateIP(%s) must return true for blocked range %s", tt.ip, tt.name) + err := mapSSRFError(tt.input) + require.Error(t, err) + assert.ErrorIs(t, err, tt.wantIs) }) } }