From 418ce48e29dbd14703d60f10aadb9987d17164ff Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 11:32:41 -0300 Subject: [PATCH 1/9] fix: apply CodeRabbit auto-fixes for PR #419 Address 21 unresolved CodeRabbit review findings across certificate, systemplane, webhook, DLQ, and idempotency packages: - certificate: defensive copies in GetCertificate/TLSCertificate, nil-receiver safety for TLSCertificate, expanded DaysUntilExpiry godoc - systemplane: nil-safety guards for SnapSettingInt/SnapSettingBool, consolidate redundant state loads in supervisor reload, robust factory assertion in backend test, duplicate env var detection in catalog validation - webhook: redact URL userinfo in logs, versioned HMAC signature format (v1) with backward-compatible migration path, DNS-free scheme validation tests, preserve original authority for HTTP Host header, defensive filter comment, consistent metric reporting with result.Attempts - dlq: clone sources slice in WithSources, reject negative MaxRetries, validate tenant key segment before Redis routing - idempotency: binary-safe cached response with headers, fail-open error handling in handleDuplicate, accurate package documentation --- commons/certificate/certificate.go | 43 +++- commons/certificate/certificate_test.go | 17 +- commons/dlq/consumer.go | 8 +- commons/dlq/handler.go | 12 +- commons/net/http/idempotency/doc.go | 76 +++--- commons/net/http/idempotency/idempotency.go | 62 ++++- commons/systemplane/bootstrap/backend_test.go | 2 +- commons/systemplane/catalog/validate.go | 41 +++- commons/systemplane/domain/setting_helpers.go | 12 +- commons/systemplane/service/supervisor.go | 8 +- .../systemplane/service/supervisor_helpers.go | 8 +- commons/webhook/deliverer.go | 188 ++++++++++++--- commons/webhook/deliverer_test.go | 220 ++++++++++++++++++ commons/webhook/doc.go | 54 ++++- commons/webhook/ssrf.go | 45 ++-- commons/webhook/ssrf_test.go | 68 ++++-- 16 files changed, 718 insertions(+), 146 deletions(-) diff 
--git a/commons/certificate/certificate.go b/commons/certificate/certificate.go index 6906d57c..194780cc 100644 --- a/commons/certificate/certificate.go +++ b/commons/certificate/certificate.go @@ -151,9 +151,10 @@ func (m *Manager) Rotate(cert *x509.Certificate, key crypto.Signer, intermediate return nil } -// GetCertificate returns the current certificate, or nil if none is loaded. -// The returned pointer shares state with the manager; callers must treat it -// as read-only. Use [LoadFromFiles] + [Manager.Rotate] to replace it. +// GetCertificate returns a deep copy of the current certificate, or nil if +// none is loaded. The returned *x509.Certificate is an independent clone that +// callers may freely modify without affecting the manager's internal state. +// Use [LoadFromFiles] + [Manager.Rotate] to replace the managed certificate. func (m *Manager) GetCertificate() *x509.Certificate { if m == nil { return nil @@ -162,7 +163,7 @@ func (m *Manager) GetCertificate() *x509.Certificate { m.mu.RLock() defer m.mu.RUnlock() - return m.cert + return cloneCert(m.cert) } // GetSigner returns the current private key as a crypto.Signer, or nil if none is loaded. @@ -212,7 +213,10 @@ func (m *Manager) ExpiresAt() time.Time { } // DaysUntilExpiry returns the number of days until the certificate expires. -// Returns -1 if no certificate is loaded. +// It returns -1 when no certificate is loaded (nil receiver or no certificate +// configured via [NewManager]). Otherwise it returns the number of days until +// expiry, which may be negative for already-expired certificates (e.g. -3 +// means the certificate expired 3 days ago). func (m *Manager) DaysUntilExpiry() int { if m == nil { return -1 @@ -228,8 +232,9 @@ func (m *Manager) DaysUntilExpiry() int { // TLSCertificate returns a [tls.Certificate] built from the currently loaded // certificate chain and private key. Returns an empty [tls.Certificate] if no -// certificate is loaded. 
The Leaf field shares state with the manager; callers -// should treat it as read-only. +// certificate is loaded. Both the Certificate [][]byte chain and the Leaf are +// deep copies, so callers never receive references aliasing internal state. +// Safe to call on a nil receiver (returns an empty [tls.Certificate]). func (m *Manager) TLSCertificate() tls.Certificate { if m == nil { return tls.Certificate{} @@ -252,7 +257,7 @@ func (m *Manager) TLSCertificate() tls.Certificate { return tls.Certificate{ Certificate: chainCopy, PrivateKey: m.signer, - Leaf: m.cert, + Leaf: cloneCert(m.cert), } } @@ -388,3 +393,25 @@ func publicKeysMatch(certPublicKey, signerPublicKey any) bool { return bytes.Equal(certDER, signerDER) } + +// cloneCert returns a deep copy of cert by re-parsing its DER encoding. +// Returns nil when cert is nil. +func cloneCert(cert *x509.Certificate) *x509.Certificate { + if cert == nil { + return nil + } + + raw := make([]byte, len(cert.Raw)) + copy(raw, cert.Raw) + + // ParseCertificate is the canonical way to deep-copy an x509.Certificate. + // Errors here are unexpected (the DER was already parsed once), but we + // return nil rather than panicking to stay consistent with the package's + // nil-safe contract. 
+ clone, err := x509.ParseCertificate(raw) + if err != nil { + return nil + } + + return clone +} diff --git a/commons/certificate/certificate_test.go b/commons/certificate/certificate_test.go index dba18d66..8b5ac909 100644 --- a/commons/certificate/certificate_test.go +++ b/commons/certificate/certificate_test.go @@ -860,20 +860,11 @@ func TestNilManager_SubTests(t *testing.T) { assert.Equal(t, -1, m.DaysUntilExpiry()) }) - t.Run("TLSCertificate returns empty", func(t *testing.T) { + t.Run("TLSCertificate returns empty on nil receiver", func(t *testing.T) { t.Parallel() - // TLSCertificate checks m == nil after acquiring the lock which it - // can't do on a nil receiver — production code checks m == nil inside - // the method body after the RLock, but since m is nil the lock call - // itself will panic. The method guards with `if m == nil` but the - // guard is AFTER the RLock. Let's verify by calling it safely. - // Actually the method body is: mu.RLock then `if m == nil` — a nil - // pointer dereference would occur. We document that TLSCertificate - // is NOT nil-receiver-safe (unlike the other methods) by asserting - // the returned value from a non-nil empty manager instead. - empty, emptyErr := NewManager("", "") - require.NoError(t, emptyErr) - tlsCert := empty.TLSCertificate() + // TLSCertificate is nil-receiver-safe: the nil guard precedes the + // RLock call, consistent with every other Manager method. + tlsCert := m.TLSCertificate() assert.Equal(t, tls.Certificate{}, tlsCert) }) } diff --git a/commons/dlq/consumer.go b/commons/dlq/consumer.go index e3c21000..55ed95bf 100644 --- a/commons/dlq/consumer.go +++ b/commons/dlq/consumer.go @@ -106,11 +106,15 @@ func WithBatchSize(n int) ConsumerOption { } } -// WithSources sets the DLQ source queue names to consume. +// WithSources sets the DLQ source queue names to consume. The input slice is +// cloned so that subsequent mutations by the caller do not race with the +// consumer's read loop. 
func WithSources(sources ...string) ConsumerOption { return func(c *Consumer) { if len(sources) > 0 { - c.cfg.Sources = sources + cloned := make([]string, len(sources)) + copy(cloned, sources) + c.cfg.Sources = cloned } } } diff --git a/commons/dlq/handler.go b/commons/dlq/handler.go index 63a92232..ac4f9c3f 100644 --- a/commons/dlq/handler.go +++ b/commons/dlq/handler.go @@ -172,7 +172,7 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { msg.CreatedAt = time.Now().UTC() } - if msg.MaxRetries == 0 { + if msg.MaxRetries <= 0 { msg.MaxRetries = h.maxRetries } @@ -188,6 +188,16 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { return fmt.Errorf("dlq: enqueue: tenant mismatch between message (%s) and context (%s)", effectiveTenant, ctxTenant) } + // Validate the effective tenant before using it to construct a Redis key. + // This prevents invalid tenant IDs from silently falling back to the global + // (non-tenant) key inside tenantScopedKeyForTenant, which would mix + // tenant-scoped messages into the global queue. + if effectiveTenant != "" { + if err := validateKeySegment("tenantID", effectiveTenant); err != nil { + return fmt.Errorf("dlq: enqueue: %w", err) + } + } + // Recalculate NextRetryAt only on initial enqueue. On re-enqueue the // consumer has already incremented RetryCount and the caller is // responsible for timing; we preserve their NextRetryAt or let the diff --git a/commons/net/http/idempotency/doc.go b/commons/net/http/idempotency/doc.go index dfd0c50c..a4bd92a6 100644 --- a/commons/net/http/idempotency/doc.go +++ b/commons/net/http/idempotency/doc.go @@ -6,37 +6,22 @@ // requests may execute more than once. Callers that require strict at-most-once // guarantees must pair this middleware with application-level safeguards. // -// The middleware uses the X-Idempotency request header combined with the tenant ID -// (from tenant-manager context) to form a composite Redis key. 
When a tenant ID is -// present, keys are scoped per-tenant to prevent cross-tenant collisions. When no -// tenant is in context (e.g., non-tenant-scoped routes), keys are still valid but -// are shared across the global namespace. Duplicate requests receive the original -// response with the X-Idempotency-Replayed header set to "true". +// # Key composition // -// # Quick start -// -// conn, err := redis.New(ctx, cfg) -// if err != nil { -// log.Fatal(err) -// } -// idem := idempotency.New(conn) -// app.Post("/orders", idem.Check(), createOrderHandler) -// -// # Behavior +// The middleware uses the X-Idempotency request header ([constants.IdempotencyKey]) +// combined with the tenant ID (from tenant-manager context via +// [tmcore.GetTenantIDContext]) to form a composite Redis key: // -// - GET/OPTIONS requests pass through (idempotency is irrelevant for reads). -// - Absent header: request proceeds normally (idempotency is opt-in). -// - Header exceeds MaxKeyLength (default 256): request rejected with 400. -// - Duplicate key: cached response replayed with X-Idempotency-Replayed: true. -// - Redis failure: request proceeds (fail-open for availability). -// - Handler success: response cached; handler failure: key deleted (client can retry). +// : // -// # Redis key namespace convention +// When a tenant ID is present, keys are scoped per-tenant to prevent +// cross-tenant collisions. When no tenant is in context (e.g., non-tenant-scoped +// routes), the key becomes : in the global namespace. // -// Keys follow the pattern: : -// with a companion response key at ::response. +// A companion response key at ::response stores +// the cached response body and headers for replay. // -// The default prefix is "idempotency:" and can be overridden via WithKeyPrefix. +// The default prefix is "idempotency:" and can be overridden via [WithKeyPrefix]. 
// This namespacing convention is consistent with other lib-commons packages that // use Redis (e.g., rate limiting uses "ratelimit::..."). Per-tenant // isolation is enforced by embedding the tenant ID into the key rather than @@ -44,7 +29,44 @@ // implementation topology-agnostic (standalone, sentinel, and cluster all behave // identically with this approach). // +// # Quick start +// +// conn, err := redis.New(ctx, cfg) +// if err != nil { +// return err +// } +// idem := idempotency.New(conn) +// app.Post("/orders", idem.Check(), createOrderHandler) +// +// # Behavior branches +// +// The [Middleware.Check] handler evaluates requests through the following +// branches in order: +// +// - GET, HEAD, and OPTIONS requests pass through unconditionally — idempotency +// is not enforced for safe/idempotent HTTP methods. +// - Absent X-Idempotency header: request proceeds normally (idempotency is +// opt-in per request). +// - Header exceeds [WithMaxKeyLength] (default 256): request is passed to the +// configured [WithRejectedHandler]. When no custom handler is set, a 400 JSON +// response with code "VALIDATION_ERROR" is returned. +// - Redis unavailable (GetClient, SetNX, or Get failures): request proceeds +// without idempotency enforcement (fail-open), logged at WARN level. +// - Duplicate key with cached response: the original response is replayed +// faithfully — status code, headers (including Location, ETag, Set-Cookie), +// content type, and body — with [constants.IdempotencyReplayed] set to "true". +// - Duplicate key still in "processing" state (in-flight): 409 Conflict with +// code "IDEMPOTENCY_CONFLICT" is returned. +// - Duplicate key in "complete" state but no cached body (e.g., body exceeded +// [WithMaxBodyCache]): 200 OK with code "IDEMPOTENT" and detail "request +// already processed" is returned. +// - Handler success: response (status, headers, body) is cached via a Redis +// pipeline and the key is marked "complete". 
+// - Handler failure: both the lock key and response key are deleted so the +// client can retry with the same idempotency key. +// // # Nil safety // -// A nil *Middleware returns a pass-through handler from Check(). +// [New] returns nil when conn is nil. A nil [*Middleware] returns a pass-through +// handler from [Middleware.Check]. package idempotency diff --git a/commons/net/http/idempotency/idempotency.go b/commons/net/http/idempotency/idempotency.go index b4714e2d..6a2ff93c 100644 --- a/commons/net/http/idempotency/idempotency.go +++ b/commons/net/http/idempotency/idempotency.go @@ -21,10 +21,15 @@ const ( ) // cachedResponse stores the full HTTP response for idempotent replay. +// Body is stored as raw bytes (base64-encoded in JSON) so that binary and +// non-UTF-8 payloads survive a marshal/unmarshal round-trip. Headers preserves +// response headers that must be faithfully replayed (e.g., Location, ETag, +// Set-Cookie). type cachedResponse struct { - StatusCode int `json:"status_code"` - ContentType string `json:"content_type"` - Body string `json:"body"` + StatusCode int `json:"status_code"` + ContentType string `json:"content_type"` + Body []byte `json:"body"` + Headers map[string][]string `json:"headers,omitempty"` } // Option configures the idempotency middleware. @@ -219,17 +224,45 @@ func (m *Middleware) handleDuplicate( key, responseKey string, ) error { // Read the current key value to distinguish in-flight from completed. - keyValue, _ := client.Get(ctx, key).Result() + keyValue, keyErr := client.Get(ctx, key).Result() + if keyErr != nil && keyErr != redis.Nil { + // Unexpected Redis error (timeout, connection failure) — fail open. + m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to read key state, failing open", + log.String("key", key), log.Err(keyErr), + ) + + return c.Next() + } // Try to replay the cached response (true idempotency). 
cached, cacheErr := client.Get(ctx, responseKey).Result() - if cacheErr == nil && cached != "" { + + switch { + case cacheErr != nil && cacheErr != redis.Nil: + // Unexpected Redis error reading cached response — fail open. + m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to read cached response, failing open", + log.String("key", responseKey), log.Err(cacheErr), + ) + + return c.Next() + case cacheErr == nil && cached != "": var resp cachedResponse if unmarshalErr := json.Unmarshal([]byte(cached), &resp); unmarshalErr == nil { + // Replay persisted headers first so the caller sees + // Location, ETag, Set-Cookie, etc. exactly as sent originally. + for name, values := range resp.Headers { + for _, v := range values { + c.Set(name, v) + } + } + c.Set(chttp.IdempotencyReplayed, "true") c.Set("Content-Type", resp.ContentType) - return c.Status(resp.StatusCode).SendString(resp.Body) + // Send (not SendString) preserves binary/non-UTF-8 bodies. + return c.Status(resp.StatusCode).Send(resp.Body) } } @@ -268,10 +301,25 @@ func (m *Middleware) saveResult( pipe := client.Pipeline() if len(body) <= m.maxBodyCache { + // Capture response headers for faithful replay. + headers := make(map[string][]string) + c.Response().Header.VisitAll(func(key, value []byte) { + name := string(key) + // Skip headers managed by the middleware itself and + // transfer-encoding / content-length which Fiber sets on send. 
+ switch name { + case "Content-Type", "Content-Length", "Transfer-Encoding", + chttp.IdempotencyReplayed: + return + } + headers[name] = append(headers[name], string(value)) + }) + resp := cachedResponse{ StatusCode: c.Response().StatusCode(), ContentType: string(c.Response().Header.ContentType()), - Body: string(body), + Body: body, + Headers: headers, } if data, marshalErr := json.Marshal(resp); marshalErr == nil { diff --git a/commons/systemplane/bootstrap/backend_test.go b/commons/systemplane/bootstrap/backend_test.go index 05b907a3..8b3b2430 100644 --- a/commons/systemplane/bootstrap/backend_test.go +++ b/commons/systemplane/bootstrap/backend_test.go @@ -408,7 +408,7 @@ func TestResetBackendFactories_ClearsRegistrations(t *testing.T) { RecordInitError(fmt.Errorf("simulated init error")) factories, initErrors := backendRegistry.snapshot() - assert.Len(t, factories, 1) + assert.Contains(t, factories, domain.BackendPostgres, "expected the registered factory to be present") assert.NotEmpty(t, initErrors) // Reset and verify. diff --git a/commons/systemplane/catalog/validate.go b/commons/systemplane/catalog/validate.go index 9caeb127..ac3984b7 100644 --- a/commons/systemplane/catalog/validate.go +++ b/commons/systemplane/catalog/validate.go @@ -115,11 +115,21 @@ func ValidateKeyDefs(productDefs []domain.KeyDef, catalogKeys ...[]SharedKey) [] // that accepts filtering options. Use WithIgnoreFields and WithKnownDeviation // to suppress expected mismatches in product catalog tests. func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]SharedKey, opts []ValidateOption) []Mismatch { - keyIndex, envIndex := buildCatalogIndexes(catalogKeys...) + keyIndex, envIndex, envConflicts := buildCatalogIndexes(catalogKeys...) 
vc := newValidateConfig(opts) var mismatches []Mismatch + for _, conflict := range envConflicts { + mismatches = append(mismatches, Mismatch{ + CatalogKey: conflict.ConflictKey, + ProductKey: conflict.ExistingKey, + Field: "EnvVar", + CatalogValue: conflict.ExistingKey, + ProductValue: conflict.ConflictKey, + }) + } + for _, pd := range productDefs { sk, found, matchedByEnv := resolveSharedKey(pd, keyIndex, envIndex) @@ -147,20 +157,45 @@ func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]Sha return mismatches } -func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[string]SharedKey) { +func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[string]SharedKey, []EnvVarConflict) { keyIndex := make(map[string]SharedKey) envIndex := make(map[string]SharedKey) + var conflicts []EnvVarConflict + for _, slice := range catalogKeys { for _, sk := range slice { keyIndex[sk.Key] = sk for _, envVar := range allowedEnvVars(sk) { + if existing, duplicate := envIndex[envVar]; duplicate { + conflicts = append(conflicts, EnvVarConflict{ + EnvVar: envVar, + ExistingKey: existing.Key, + ConflictKey: sk.Key, + }) + + continue + } + envIndex[envVar] = sk } } } - return keyIndex, envIndex + return keyIndex, envIndex, conflicts +} + +// EnvVarConflict describes a duplicate environment variable mapping detected +// during catalog index construction. Two SharedKey entries map the same env var. +type EnvVarConflict struct { + EnvVar string // the duplicated environment variable + ExistingKey string // the SharedKey.Key that was registered first + ConflictKey string // the SharedKey.Key that attempted to overwrite +} + +// String returns a human-readable description of the conflict. 
+func (c EnvVarConflict) String() string { + return fmt.Sprintf("duplicate env var %q: mapped by both %q and %q", c.EnvVar, c.ExistingKey, c.ConflictKey) } func resolveSharedKey(pd domain.KeyDef, keyIndex map[string]SharedKey, envIndex map[string]SharedKey) (SharedKey, bool, bool) { diff --git a/commons/systemplane/domain/setting_helpers.go b/commons/systemplane/domain/setting_helpers.go index 62c10573..43f5730b 100644 --- a/commons/systemplane/domain/setting_helpers.go +++ b/commons/systemplane/domain/setting_helpers.go @@ -9,13 +9,13 @@ func SnapSettingString(snap *Snapshot, tenantID, key string, fallback string) st return fallback } - if raw, ok := snap.GetTenantSetting(tenantID, key); ok { + if raw, ok := snap.GetTenantSetting(tenantID, key); ok && raw.Value != nil { if value, converted := tryCoerceString(raw.Value); converted { return value } } - if raw, ok := snap.GetGlobalSetting(key); ok { + if raw, ok := snap.GetGlobalSetting(key); ok && raw.Value != nil { if value, converted := tryCoerceString(raw.Value); converted { return value } @@ -31,13 +31,13 @@ func SnapSettingInt(snap *Snapshot, tenantID, key string, fallback int) int { return fallback } - if raw, ok := snap.GetTenantSetting(tenantID, key); ok { + if raw, ok := snap.GetTenantSetting(tenantID, key); ok && raw.Value != nil { if value, converted := tryCoerceInt(raw.Value); converted { return value } } - if raw, ok := snap.GetGlobalSetting(key); ok { + if raw, ok := snap.GetGlobalSetting(key); ok && raw.Value != nil { if value, converted := tryCoerceInt(raw.Value); converted { return value } @@ -53,13 +53,13 @@ func SnapSettingBool(snap *Snapshot, tenantID, key string, fallback bool) bool { return fallback } - if raw, ok := snap.GetTenantSetting(tenantID, key); ok { + if raw, ok := snap.GetTenantSetting(tenantID, key); ok && raw.Value != nil { if value, converted := tryCoerceBool(raw.Value); converted { return value } } - if raw, ok := snap.GetGlobalSetting(key); ok { + if raw, ok := 
snap.GetGlobalSetting(key); ok && raw.Value != nil { if value, converted := tryCoerceBool(raw.Value); converted { return value } diff --git a/commons/systemplane/service/supervisor.go b/commons/systemplane/service/supervisor.go index 5c07434e..cf5fa844 100644 --- a/commons/systemplane/service/supervisor.go +++ b/commons/systemplane/service/supervisor.go @@ -212,14 +212,16 @@ func (supervisor *defaultSupervisor) Reload(ctx context.Context, reason string, supervisor.mu.Lock() defer supervisor.mu.Unlock() + currentState := supervisor.state.Load() + var prevSnap *domain.Snapshot - if st := supervisor.state.Load(); st != nil { - prevSnap = &st.snapshot + if currentState != nil { + prevSnap = ¤tState.snapshot } tenantIDs := mergeUniqueTenantIDs(cachedTenantIDs(prevSnap), extraTenantIDs) - build, err := supervisor.prepareReloadBuild(ctx, tenantIDs) + build, err := supervisor.prepareReloadBuild(ctx, tenantIDs, currentState) if err != nil { libOpentelemetry.HandleSpanError(span, "build runtime bundle", err) return err diff --git a/commons/systemplane/service/supervisor_helpers.go b/commons/systemplane/service/supervisor_helpers.go index 32df7b5f..e1e170eb 100644 --- a/commons/systemplane/service/supervisor_helpers.go +++ b/commons/systemplane/service/supervisor_helpers.go @@ -147,7 +147,7 @@ func (supervisor *defaultSupervisor) buildBundle( // AdoptResourcesFrom, which runs AFTER the atomic state swap. At that point // Current() already returns the new candidate, so concurrent readers never // observe the mutation. -func (supervisor *defaultSupervisor) prepareReloadBuild(ctx context.Context, tenantIDs []string) (reloadBuild, error) { +func (supervisor *defaultSupervisor) prepareReloadBuild(ctx context.Context, tenantIDs []string, currentState *supervisorState) (reloadBuild, error) { snap, err := supervisor.builder.BuildFull(ctx, tenantIDs...) 
if err != nil { return reloadBuild{}, fmt.Errorf("reload: %w: %w", domain.ErrSnapshotBuildFailed, err) @@ -157,9 +157,9 @@ func (supervisor *defaultSupervisor) prepareReloadBuild(ctx context.Context, ten var previousBundle domain.RuntimeBundle - if st := supervisor.state.Load(); st != nil { - prevSnap = &st.snapshot - previousBundle = st.bundle + if currentState != nil { + prevSnap = ¤tState.snapshot + previousBundle = currentState.bundle } candidate, strategy, err := supervisor.buildBundle(ctx, snap, previousBundle, prevSnap) diff --git a/commons/webhook/deliverer.go b/commons/webhook/deliverer.go index bb2374df..9747dab9 100644 --- a/commons/webhook/deliverer.go +++ b/commons/webhook/deliverer.go @@ -26,6 +26,21 @@ import ( "github.com/LerianStudio/lib-commons/v4/commons/runtime" ) +// SignatureVersion controls the HMAC signing format used for X-Webhook-Signature. +type SignatureVersion int + +const ( + // SignatureV0 produces legacy payload-only signatures: "sha256=". + // This is the default for backward compatibility with existing consumers. + SignatureV0 SignatureVersion = iota + + // SignatureV1 produces versioned timestamp-bound signatures: + // "v1,sha256=.))>". + // The timestamp is included in the HMAC input to prevent replay attacks. + // Receivers must verify freshness (e.g., reject timestamps older than 5 minutes). + SignatureV1 +) + // Defaults for Deliverer configuration. const ( defaultMaxConcurrency = 20 @@ -51,6 +66,7 @@ type Deliverer struct { decryptor SecretDecryptor maxConc int maxRetries int + sigVersion SignatureVersion } // Option configures a Deliverer at construction time. @@ -128,6 +144,20 @@ func WithSecretDecryptor(fn SecretDecryptor) Option { } } +// WithSignatureVersion selects the HMAC signing format for X-Webhook-Signature. +// The default is SignatureV0 (payload-only) for backward compatibility. 
+// Use SignatureV1 to include the event timestamp in the HMAC input and produce +// a versioned signature string that enables replay protection. +// +// Migration path: switch to SignatureV1 only after all consumers have been +// updated to verify the "v1,sha256=..." format. See VerifySignatureV1 for +// the receiver-side verification logic. +func WithSignatureVersion(v SignatureVersion) Option { + return func(d *Deliverer) { + d.sigVersion = v + } +} + // defaultHTTPClient creates an http.Client optimized for webhook delivery. // Connection pooling avoids TCP+TLS handshake overhead on repeated deliveries // to the same endpoint — critical at scale where hundreds of webhooks per @@ -202,6 +232,8 @@ func (d *Deliverer) Deliver(ctx context.Context, event *Event) error { return fmt.Errorf("webhook: list endpoints: %w", err) } + // Defensive filter: EndpointLister contract guarantees active-only, + // but guard against faulty implementations. active := filterActive(endpoints) if len(active) == 0 { d.log(ctx, log.LevelDebug, "no active endpoints for event", @@ -239,6 +271,8 @@ func (d *Deliverer) DeliverWithResults(ctx context.Context, event *Event) []Deli }} } + // Defensive filter: EndpointLister contract guarantees active-only, + // but guard against faulty implementations. 
active := filterActive(endpoints) if len(active) == 0 { return nil @@ -335,7 +369,7 @@ func (d *Deliverer) deliverToEndpoint( result := DeliveryResult{EndpointID: ep.ID} // --- SSRF validation + DNS pinning (single lookup, eliminates TOCTOU) --- - pinnedURL, originalHost, ssrfErr := resolveAndValidateIP(ctx, ep.URL) + pinnedURL, originalAuthority, sniHostname, ssrfErr := resolveAndValidateIP(ctx, ep.URL) if ssrfErr != nil { span.RecordError(ssrfErr) span.SetStatus(codes.Error, "SSRF blocked") @@ -379,7 +413,7 @@ func (d *Deliverer) deliverToEndpoint( } } - statusCode, err := d.doHTTP(ctx, pinnedURL, originalHost, event, secret) + statusCode, err := d.doHTTP(ctx, pinnedURL, originalAuthority, sniHostname, event, secret) result.StatusCode = statusCode if err != nil { @@ -408,7 +442,7 @@ func (d *Deliverer) deliverToEndpoint( // Non-retryable client errors — break immediately (except 429 Too Many Requests). if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError && statusCode != http.StatusTooManyRequests { result.Error = fmt.Errorf("webhook: non-retryable status %d", statusCode) - d.recordMetrics(ctx, ep.ID, false, statusCode, attempt+1) + d.recordMetrics(ctx, ep.ID, false, statusCode, result.Attempts) return result } @@ -438,10 +472,16 @@ func (d *Deliverer) deliverToEndpoint( // doHTTP builds and executes a single HTTP request to the (possibly pinned) URL. // Returns the status code and any transport-level error. +// +// originalAuthority is the full host:port authority from the original URL, used +// as the HTTP Host header to preserve explicit non-default ports. +// sniHostname is the bare hostname (port stripped), used for TLS SNI and +// certificate verification when the URL has been rewritten to a pinned IP. 
func (d *Deliverer) doHTTP( ctx context.Context, pinnedURL string, - originalHost string, + originalAuthority string, + sniHostname string, event *Event, secret string, ) (int, error) { @@ -455,21 +495,22 @@ func (d *Deliverer) doHTTP( req.Header.Set("X-Webhook-Timestamp", strconv.FormatInt(event.Timestamp, 10)) // When the URL was rewritten to use the pinned IP, set the Host header - // for virtual hosting and use TLSClientConfig.ServerName so TLS SNI and - // certificate verification use the original hostname (not the IP). - if originalHost != "" { - req.Host = originalHost + // with the original authority (host:port) for virtual hosting, and use + // TLSClientConfig.ServerName with the bare hostname for TLS SNI and + // certificate verification (not the IP). + if originalAuthority != "" { + req.Host = originalAuthority } if secret != "" { - sig := computeHMAC(event.Payload, secret) - req.Header.Set("X-Webhook-Signature", "sha256="+sig) + sig := d.computeSignature(event.Payload, event.Timestamp, secret) + req.Header.Set("X-Webhook-Signature", sig) } client := d.client - if originalHost != "" && strings.HasPrefix(pinnedURL, "https://") { - client = d.httpsClientForPinnedIP(originalHost) + if sniHostname != "" && strings.HasPrefix(pinnedURL, "https://") { + client = d.httpsClientForPinnedIP(sniHostname) } resp, err := client.Do(req) @@ -506,18 +547,25 @@ func (d *Deliverer) resolveSecret(raw string) (string, error) { return plaintext, nil } +// computeSignature dispatches to the appropriate HMAC format based on the +// configured signature version. +func (d *Deliverer) computeSignature(payload []byte, timestamp int64, secret string) string { + switch d.sigVersion { + case SignatureV1: + return computeHMACv1(payload, timestamp, secret) + default: + return "sha256=" + computeHMAC(payload, secret) + } +} + // computeHMAC returns the hex-encoded HMAC-SHA256 of payload using the given secret. +// This is the legacy (v0) format that signs the raw payload only. 
// -// Design note — timestamp not included in signature (by intent): +// Design note — timestamp not included in v0 signature (by intent): // The signature covers the raw payload only, not the X-Webhook-Timestamp value. -// Some industry integrations (e.g., Stripe) sign "timestamp.payload" to bind the -// timestamp to the signature and prevent replay attacks. We intentionally do not -// do this because: (a) changing the input format would silently break all existing -// consumers who already verify "sha256=HMAC(payload)" and (b) replay protection -// at the application layer (e.g., idempotency keys, short timestamp windows) is -// the responsibility of the receiving service. Receivers who need replay protection -// should validate that X-Webhook-Timestamp is within an acceptable window (e.g., -// ±5 minutes) independently of the signature check. +// This format is maintained for backward compatibility. New deployments should +// prefer SignatureV1 (via WithSignatureVersion) which binds the timestamp into +// the HMAC input to enable replay protection. func computeHMAC(payload []byte, secret string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(payload) @@ -525,6 +573,91 @@ func computeHMAC(payload []byte, secret string) string { return hex.EncodeToString(mac.Sum(nil)) } +// computeHMACv1 returns a versioned signature string: +// +// "v1,sha256=.", secret))>" +// +// The version prefix "v1:" followed by the decimal timestamp and a dot separator +// are prepended to the payload before computing the HMAC, binding the timestamp +// to the signature. Receivers must parse the "v1," prefix, extract the timestamp +// from the X-Webhook-Timestamp header, reconstruct the signing input, and verify +// the HMAC before accepting the webhook. 
+func computeHMACv1(payload []byte, timestamp int64, secret string) string { + ts := strconv.FormatInt(timestamp, 10) + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte("v1:")) + mac.Write([]byte(ts)) + mac.Write([]byte(".")) + mac.Write(payload) + + return "v1,sha256=" + hex.EncodeToString(mac.Sum(nil)) +} + +// VerifySignature verifies a webhook signature by auto-detecting the version +// from the signature string format. Returns nil on valid signature, or an error +// describing the mismatch. +// +// Supported formats: +// - "sha256=" — v0 (payload-only) +// - "v1,sha256=" — v1 (timestamp-bound) +// +// For v1 signatures, the timestamp parameter is required to reconstruct the +// signing input. For v0 signatures, the timestamp is ignored. +func VerifySignature(payload []byte, timestamp int64, secret, signature string) error { + switch { + case strings.HasPrefix(signature, "v1,"): + expected := computeHMACv1(payload, timestamp, secret) + if !hmac.Equal([]byte(signature), []byte(expected)) { + return errors.New("webhook: v1 signature mismatch") + } + + return nil + + case strings.HasPrefix(signature, "sha256="): + expected := "sha256=" + computeHMAC(payload, secret) + if !hmac.Equal([]byte(signature), []byte(expected)) { + return errors.New("webhook: v0 signature mismatch") + } + + return nil + + default: + return errors.New("webhook: unrecognized signature format") + } +} + +// VerifySignatureWithFreshness verifies a v1 webhook signature and additionally +// checks that the timestamp is within the given tolerance window from now. +// This provides replay protection: even if an attacker captures a valid +// payload+signature pair, it becomes invalid after the tolerance window expires. +// +// For v0 ("sha256=...") signatures, freshness cannot be enforced because the +// timestamp is not covered by the HMAC. 
Callers receiving v0 signatures should +// use VerifySignature and implement replay protection independently (e.g., +// idempotency keys or event-ID tracking). +func VerifySignatureWithFreshness(payload []byte, timestamp int64, secret, signature string, tolerance time.Duration) error { + if err := VerifySignature(payload, timestamp, secret, signature); err != nil { + return err + } + + // Freshness check only applies to v1 where the timestamp is signed. + if strings.HasPrefix(signature, "v1,") { + eventTime := time.Unix(timestamp, 0) + delta := time.Since(eventTime) + + if delta < 0 { + delta = -delta + } + + if delta > tolerance { + return fmt.Errorf("webhook: timestamp outside tolerance window (%s > %s)", delta.Truncate(time.Second), tolerance) + } + } + + return nil +} + // filterActive returns only endpoints where Active is true. func filterActive(endpoints []Endpoint) []Endpoint { active := make([]Endpoint, 0, len(endpoints)) @@ -605,17 +738,20 @@ func (d *Deliverer) httpsClientForPinnedIP(originalHost string) *http.Client { return &clone } -// sanitizeURL strips query parameters from a URL before logging to prevent -// credential leakage. Webhook URLs may carry tokens in query params -// (e.g., ?token=..., ?api_key=...) that must not appear in log output. -// On parse failure the raw string is returned unchanged so no log line is lost. +// sanitizeURL strips query parameters and userinfo from a URL before logging +// to prevent credential leakage. Webhook URLs may carry tokens in query params +// (e.g., ?token=..., ?api_key=...) or credentials in the userinfo component +// (e.g., https://user:pass@host/...) that must not appear in log output. +// On parse failure a safe placeholder is returned instead of the raw input +// to avoid leaking credentials embedded in malformed URLs. 
func sanitizeURL(rawURL string) string { u, err := url.Parse(rawURL) if err != nil { - return rawURL + return "[invalid-url]" } u.RawQuery = "" + u.User = nil return u.String() } diff --git a/commons/webhook/deliverer_test.go b/commons/webhook/deliverer_test.go index ef4fe7dc..43ec3fd8 100644 --- a/commons/webhook/deliverer_test.go +++ b/commons/webhook/deliverer_test.go @@ -14,6 +14,7 @@ import ( "net/http" "net/http/httptest" "strconv" + "strings" "sync" "sync/atomic" "testing" @@ -900,3 +901,222 @@ func TestResolveSecret(t *testing.T) { }) } } + +// --------------------------------------------------------------------------- +// sanitizeURL — credential redaction +// --------------------------------------------------------------------------- + +func TestSanitizeURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + want string + }{ + { + name: "strips query params", + raw: "https://example.com/webhook?token=secret&api_key=abc", + want: "https://example.com/webhook", + }, + { + name: "strips userinfo", + raw: "https://user:pass@example.com/webhook", + want: "https://example.com/webhook", + }, + { + name: "strips both userinfo and query", + raw: "https://admin:hunter2@example.com/hook?key=val", + want: "https://example.com/hook", + }, + { + name: "no credentials passes through", + raw: "https://example.com/webhook", + want: "https://example.com/webhook", + }, + { + name: "invalid URL returns placeholder", + raw: "://bad\x7f", + want: "[invalid-url]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := sanitizeURL(tt.raw) + assert.Equal(t, tt.want, got) + }) + } +} + +// --------------------------------------------------------------------------- +// Versioned signatures — v1 HMAC with timestamp binding +// --------------------------------------------------------------------------- + +func TestComputeHMACv1(t *testing.T) { + t.Parallel() + + payload := []byte(`{"id":"123"}`) + secret 
:= "test-secret" + timestamp := int64(1700000000) + + sig := computeHMACv1(payload, timestamp, secret) + assert.True(t, strings.HasPrefix(sig, "v1,sha256="), + "v1 signature must start with 'v1,sha256='") + + // Verify it is deterministic. + sig2 := computeHMACv1(payload, timestamp, secret) + assert.Equal(t, sig, sig2, "same input must produce same signature") + + // Different timestamp must produce different signature. + sig3 := computeHMACv1(payload, timestamp+1, secret) + assert.NotEqual(t, sig, sig3, "different timestamp must produce different signature") +} + +func TestComputeSignature_V0Default(t *testing.T) { + t.Parallel() + + d := &Deliverer{sigVersion: SignatureV0} + payload := []byte(`{"id":"123"}`) + timestamp := int64(1700000000) + secret := "test-secret" + + sig := d.computeSignature(payload, timestamp, secret) + assert.True(t, strings.HasPrefix(sig, "sha256="), + "v0 signature must start with 'sha256='") + assert.False(t, strings.HasPrefix(sig, "v1,"), + "v0 signature must NOT have v1 prefix") +} + +func TestComputeSignature_V1(t *testing.T) { + t.Parallel() + + d := &Deliverer{sigVersion: SignatureV1} + payload := []byte(`{"id":"123"}`) + timestamp := int64(1700000000) + secret := "test-secret" + + sig := d.computeSignature(payload, timestamp, secret) + assert.True(t, strings.HasPrefix(sig, "v1,sha256="), + "v1 signature must start with 'v1,sha256='") +} + +func TestVerifySignature(t *testing.T) { + t.Parallel() + + payload := []byte(`{"id":"123"}`) + secret := "test-secret" + timestamp := int64(1700000000) + + tests := []struct { + name string + sig string + wantErr string + }{ + { + name: "valid v0 signature", + sig: "sha256=" + computeHMAC(payload, secret), + }, + { + name: "valid v1 signature", + sig: computeHMACv1(payload, timestamp, secret), + }, + { + name: "invalid v0 signature", + sig: "sha256=deadbeef", + wantErr: "v0 signature mismatch", + }, + { + name: "invalid v1 signature", + sig: "v1,sha256=deadbeef", + wantErr: "v1 signature 
mismatch", + }, + { + name: "unrecognized format", + sig: "v99,md5=abc", + wantErr: "unrecognized signature format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := VerifySignature(payload, timestamp, secret, tt.sig) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestVerifySignatureWithFreshness(t *testing.T) { + t.Parallel() + + payload := []byte(`{"id":"123"}`) + secret := "test-secret" + now := time.Now().Unix() + + // Fresh v1 signature (timestamp = now). + freshSig := computeHMACv1(payload, now, secret) + + err := VerifySignatureWithFreshness(payload, now, secret, freshSig, 5*time.Minute) + assert.NoError(t, err, "fresh v1 signature within tolerance must pass") + + // Stale v1 signature (timestamp = 1 hour ago). + staleTS := now - 3600 + staleSig := computeHMACv1(payload, staleTS, secret) + + err = VerifySignatureWithFreshness(payload, staleTS, secret, staleSig, 5*time.Minute) + require.Error(t, err) + assert.Contains(t, err.Error(), "outside tolerance window") + + // V0 signature — freshness check is skipped (timestamp not signed). 
+ v0Sig := "sha256=" + computeHMAC(payload, secret) + + err = VerifySignatureWithFreshness(payload, staleTS, secret, v0Sig, 5*time.Minute) + assert.NoError(t, err, "v0 signature skips freshness check") +} + +func TestDeliver_V1Signature_EndToEnd(t *testing.T) { + t.Parallel() + + secret := "v1-secret" + event := newTestEvent() + + var gotSig string + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSig = r.Header.Get("X-Webhook-Signature") + w.WriteHeader(http.StatusOK) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-v1", URL: pubURL, Secret: secret, Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithSignatureVersion(SignatureV1), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), event) + require.Len(t, results, 1) + require.True(t, results[0].Success) + + // The signature should be v1 format. + assert.True(t, strings.HasPrefix(gotSig, "v1,sha256="), + "v1 deliverer must produce v1 signature format, got: %s", gotSig) + + // Verify the signature is correct. 
+ err := VerifySignature(event.Payload, event.Timestamp, secret, gotSig) + assert.NoError(t, err, "v1 signature from deliverer must verify correctly") +} diff --git a/commons/webhook/doc.go b/commons/webhook/doc.go index ded26825..06de7ba8 100644 --- a/commons/webhook/doc.go +++ b/commons/webhook/doc.go @@ -6,8 +6,8 @@ // - Two-layer SSRF protection: pre-resolution IP validation + DNS-pinned delivery // - DNS rebinding prevention via resolved IP pinning // - Blocks private, loopback, and link-local IP ranges -// - HMAC-SHA256 signature in X-Webhook-Signature header (sha256=HEX) -// - URL query parameters stripped from log output to prevent credential leakage +// - HMAC-SHA256 signature in X-Webhook-Signature header (versioned format) +// - URL query parameters and userinfo stripped from log output to prevent credential leakage // // # Delivery model // @@ -15,16 +15,46 @@ // - Exponential backoff with jitter (1s, 2s, 4s, ...) // - Per-endpoint retry with configurable max attempts (default: 3) // -// # HMAC signature scope +// # HMAC signature versions // -// The X-Webhook-Signature header carries HMAC-SHA256 computed over the raw -// payload bytes only. X-Webhook-Timestamp is sent as a separate informational -// header but is NOT covered by the HMAC — an attacker who captures a valid -// payload+signature pair can replay it with a fresh timestamp. +// The X-Webhook-Signature header supports two formats, selectable via +// WithSignatureVersion at Deliverer construction time: // -// Receivers who need replay protection must implement it independently, for -// example by tracking event IDs or embedding a nonce in the payload itself. -// Timestamp-window checks alone are insufficient because the timestamp is -// unsigned. Including the timestamp in the HMAC would be a breaking change -// for existing consumers. +// ## v0 (default — backward compatible) +// +// Format: "sha256=" +// +// The HMAC covers the raw payload bytes only. 
X-Webhook-Timestamp is sent as +// a separate informational header but is NOT covered by the HMAC. An attacker +// who captures a valid payload+signature pair can replay it. Receivers must +// implement replay protection independently (e.g., event-ID tracking, +// idempotency keys). +// +// ## v1 (recommended for new deployments) +// +// Format: "v1,sha256=.", secret))>" +// +// The HMAC input includes a version prefix, the decimal Unix-epoch timestamp +// from X-Webhook-Timestamp, and the raw payload — binding the timestamp to +// the signature. Receivers must: +// 1. Parse the "v1," prefix from X-Webhook-Signature. +// 2. Read X-Webhook-Timestamp and reconstruct the signing input as +// "v1:.". +// 3. Compute HMAC-SHA256 with the shared secret and compare (constant-time). +// 4. Reject timestamps outside an acceptable clock-skew window (e.g., +/- 5 min) +// or track event IDs / nonces to prevent replay. +// +// Use VerifySignature or VerifySignatureWithFreshness for receiver-side +// verification — both auto-detect the version from the signature string. +// +// # Migration from v0 to v1 +// +// Because this is a library used by multiple services, the default remains v0 +// to avoid breaking existing consumers. To migrate: +// 1. Update all webhook receivers to accept both v0 and v1 formats (use +// VerifySignature which auto-detects the version). +// 2. Once all receivers are updated, switch senders to v1 by constructing +// the Deliverer with WithSignatureVersion(SignatureV1). +// 3. After a transition period, receivers may optionally reject v0 signatures +// to enforce replay protection. package webhook diff --git a/commons/webhook/ssrf.go b/commons/webhook/ssrf.go index 46637131..52649cfb 100644 --- a/commons/webhook/ssrf.go +++ b/commons/webhook/ssrf.go @@ -56,35 +56,37 @@ var additionalBlockedRanges = []*net.IPNet{ // the pinned IP to differ from the validated one. 
// // On success it returns: -// - pinnedURL — original URL with the hostname replaced by the first resolved IP. -// - originalHost — the original hostname, for use as the HTTP Host header (TLS SNI). +// - pinnedURL — original URL with the hostname replaced by the first resolved IP. +// - originalAuthority — the original host:port authority (u.Host), suitable for the +// HTTP Host header so explicit non-default ports are preserved. +// - sniHostname — the bare hostname (u.Hostname(), port stripped), suitable for +// TLS SNI / certificate verification. // // DNS lookup failures are fail-closed: if the hostname cannot be resolved, the // URL is rejected. When no resolved IP can be parsed from the DNS response the // URL is considered unresolvable and an error is returned. -func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL string, originalHost string, err error) { +func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL, originalAuthority, sniHostname string, err error) { u, err := url.Parse(rawURL) if err != nil { - return "", "", fmt.Errorf("%w: %w", ErrInvalidURL, err) + return "", "", "", fmt.Errorf("%w: %w", ErrInvalidURL, err) } - scheme := strings.ToLower(u.Scheme) - if scheme != "http" && scheme != "https" { - return "", "", fmt.Errorf("%w: scheme %q not allowed", ErrSSRFBlocked, scheme) + if err := validateScheme(u.Scheme); err != nil { + return "", "", "", err } host := u.Hostname() if host == "" { - return "", "", fmt.Errorf("%w: empty hostname", ErrInvalidURL) + return "", "", "", fmt.Errorf("%w: empty hostname", ErrInvalidURL) } ips, dnsErr := net.DefaultResolver.LookupHost(ctx, host) if dnsErr != nil { - return "", "", fmt.Errorf("%w: DNS lookup failed for %s: %w", ErrSSRFBlocked, host, dnsErr) + return "", "", "", fmt.Errorf("%w: DNS lookup failed for %s: %w", ErrSSRFBlocked, host, dnsErr) } if len(ips) == 0 { - return "", "", fmt.Errorf("%w: DNS returned no addresses for %s", ErrSSRFBlocked, host) + return "", 
"", "", fmt.Errorf("%w: DNS returned no addresses for %s", ErrSSRFBlocked, host) } var firstValidIP string @@ -96,7 +98,7 @@ func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL string, } if isPrivateIP(ip) { - return "", "", fmt.Errorf("%w: resolved IP %s is private/loopback", ErrSSRFBlocked, ipStr) + return "", "", "", fmt.Errorf("%w: resolved IP %s is private/loopback", ErrSSRFBlocked, ipStr) } if firstValidIP == "" { @@ -105,9 +107,13 @@ func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL string, } if firstValidIP == "" { - return "", "", fmt.Errorf("%w: no valid IPs resolved for %s", ErrInvalidURL, host) + return "", "", "", fmt.Errorf("%w: no valid IPs resolved for %s", ErrInvalidURL, host) } + // Preserve the original authority (host:port) for the HTTP Host header + // before rewriting u.Host to the pinned IP. + authority := u.Host + // Pin to first valid resolved IP to prevent DNS rebinding across retries. port := u.Port() @@ -121,7 +127,20 @@ func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL string, u.Host = firstValidIP } - return u.String(), host, nil + return u.String(), authority, host, nil +} + +// validateScheme checks that the URL scheme is http or https. All other +// schemes are rejected to prevent SSRF via exotic protocols (file://, gopher://, etc.). +// This is a pure function with no DNS or I/O dependency, making it safe for +// isolated unit testing. 
+func validateScheme(scheme string) error { + s := strings.ToLower(scheme) + if s != "http" && s != "https" { + return fmt.Errorf("%w: scheme %q not allowed", ErrSSRFBlocked, scheme) + } + + return nil } // isPrivateIP reports whether ip is in a private, loopback, link-local, diff --git a/commons/webhook/ssrf_test.go b/commons/webhook/ssrf_test.go index 456cb2dd..b0ea8b4c 100644 --- a/commons/webhook/ssrf_test.go +++ b/commons/webhook/ssrf_test.go @@ -4,7 +4,6 @@ package webhook import ( "context" - "errors" "net" "testing" @@ -86,7 +85,7 @@ func TestIsPrivateIP(t *testing.T) { func TestResolveAndValidateIP_InvalidURLMalformed(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), "://missing-scheme") + _, _, _, err := resolveAndValidateIP(context.Background(), "://missing-scheme") assert.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) } @@ -94,7 +93,7 @@ func TestResolveAndValidateIP_InvalidURLMalformed(t *testing.T) { func TestResolveAndValidateIP_EmptyHostnameHTTP(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), "http://") + _, _, _, err := resolveAndValidateIP(context.Background(), "http://") assert.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) assert.Contains(t, err.Error(), "empty hostname") @@ -122,35 +121,64 @@ func TestResolveAndValidateIP_UnsupportedSchemes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), tt.url) + _, _, _, err := resolveAndValidateIP(context.Background(), tt.url) assert.Error(t, err) assert.ErrorIs(t, err, ErrSSRFBlocked) }) } } -func TestResolveAndValidateIP_AllowedSchemes(t *testing.T) { +func TestValidateScheme_AllowedSchemes(t *testing.T) { t.Parallel() - // These schemes should pass the scheme check (they may still fail DNS - // resolution, but they should NOT fail with ErrSSRFBlocked for scheme). 
+ // These schemes should pass the scheme check — tested via the pure + // validateScheme function to avoid any DNS dependency. tests := []struct { - name string - url string + name string + scheme string + }{ + {name: "http scheme", scheme: "http"}, + {name: "https scheme", scheme: "https"}, + {name: "HTTP uppercase", scheme: "HTTP"}, + {name: "HTTPS uppercase", scheme: "HTTPS"}, + {name: "mixed case Http", scheme: "Http"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := validateScheme(tt.scheme) + assert.NoError(t, err, "scheme %q must be allowed", tt.scheme) + }) + } +} + +func TestValidateScheme_BlockedSchemes(t *testing.T) { + t.Parallel() + + // These schemes must be rejected — tested via the pure validateScheme + // function to avoid any DNS dependency. + tests := []struct { + name string + scheme string }{ - {name: "http scheme", url: "http://example.com"}, - {name: "https scheme", url: "https://example.com"}, + {name: "gopher scheme", scheme: "gopher"}, + {name: "file scheme", scheme: "file"}, + {name: "ftp scheme", scheme: "ftp"}, + {name: "javascript scheme", scheme: "javascript"}, + {name: "data scheme", scheme: "data"}, + {name: "empty scheme", scheme: ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), tt.url) - // The error, if any, should NOT be ErrSSRFBlocked for scheme reasons. 
- if err != nil { - assert.False(t, errors.Is(err, ErrSSRFBlocked), "http/https should not be SSRF-blocked: %v", err) - } + err := validateScheme(tt.scheme) + require.Error(t, err) + assert.ErrorIs(t, err, ErrSSRFBlocked, + "scheme %q must be blocked", tt.scheme) }) } } @@ -177,7 +205,7 @@ func TestResolveAndValidateIP_InvalidScheme(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), tt.url) + _, _, _, err := resolveAndValidateIP(context.Background(), tt.url) require.Error(t, err) assert.ErrorIs(t, err, ErrSSRFBlocked, "non-HTTP/HTTPS scheme must return ErrSSRFBlocked") @@ -189,7 +217,7 @@ func TestResolveAndValidateIP_InvalidScheme(t *testing.T) { func TestResolveAndValidateIP_EmptyHostname(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), "http://") + _, _, _, err := resolveAndValidateIP(context.Background(), "http://") require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) assert.Contains(t, err.Error(), "empty hostname") @@ -200,7 +228,7 @@ func TestResolveAndValidateIP_EmptyHostname(t *testing.T) { func TestResolveAndValidateIP_InvalidURL(t *testing.T) { t.Parallel() - _, _, err := resolveAndValidateIP(context.Background(), "://no-scheme") + _, _, _, err := resolveAndValidateIP(context.Background(), "://no-scheme") require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) } @@ -223,7 +251,7 @@ func TestResolveAndValidateIP_PrivateIP(t *testing.T) { t.Skip("localhost DNS lookup failed or returned no results — skipping SSRF private-IP test") } - _, _, err := resolveAndValidateIP(context.Background(), "http://localhost") + _, _, _, err := resolveAndValidateIP(context.Background(), "http://localhost") require.Error(t, err) assert.ErrorIs(t, err, ErrSSRFBlocked, "localhost (127.0.0.1) must be blocked as a private/loopback address") From ce56bc2bf99dc44a0397408d22d4156e843f97d4 Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 11:41:57 
-0300 Subject: [PATCH 2/9] fix: reduce Enqueue cyclomatic complexity and add missing errors import - Extract validateEnqueueMessage, stampInitialEnqueue, and resolveAndValidateTenant helpers from Handler.Enqueue to bring cyclomatic complexity from 18 down to 7 - Add missing "errors" import in idempotency package for errors.Is calls --- commons/dlq/handler.go | 114 ++++++++++++-------- commons/net/http/idempotency/idempotency.go | 7 +- 2 files changed, 74 insertions(+), 47 deletions(-) diff --git a/commons/dlq/handler.go b/commons/dlq/handler.go index ac4f9c3f..0b9a4188 100644 --- a/commons/dlq/handler.go +++ b/commons/dlq/handler.go @@ -148,6 +148,50 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { return ErrNilHandler } + if err := validateEnqueueMessage(msg); err != nil { + return err + } + + ctx, span := h.tracer.Start(ctx, "dlq.enqueue") + defer span.End() + + h.stampInitialEnqueue(msg) + + effectiveTenant, err := h.resolveAndValidateTenant(ctx, msg) + if err != nil { + return err + } + + data, err := json.Marshal(msg) + if err != nil { + libOtel.HandleSpanError(span, "dlq marshal failed", err) + + return fmt.Errorf("dlq: enqueue: marshal: %w", err) + } + + key := h.tenantScopedKeyForTenant(effectiveTenant, msg.Source) + + rds, err := h.conn.GetClient(ctx) + if err != nil { + libOtel.HandleSpanError(span, "dlq redis client unavailable", err) + h.logEnqueueFallback(ctx, key, msg, err) + + return fmt.Errorf("dlq: enqueue: redis client: %w", err) + } + + if pushErr := rds.RPush(ctx, key, data).Err(); pushErr != nil { + libOtel.HandleSpanError(span, "dlq rpush failed", pushErr) + h.logEnqueueFallback(ctx, key, msg, pushErr) + + return fmt.Errorf("dlq: enqueue: rpush: %w", pushErr) + } + + return nil +} + +// validateEnqueueMessage performs pre-flight validation on the message before +// any state mutation or tracing begins. 
+func validateEnqueueMessage(msg *FailedMessage) error { if msg == nil { return errors.New("dlq: enqueue: nil message") } @@ -156,17 +200,15 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { return errors.New("dlq: enqueue: source must not be empty") } - if err := validateKeySegment("source", msg.Source); err != nil { - return err - } - - ctx, span := h.tracer.Start(ctx, "dlq.enqueue") - defer span.End() + return validateKeySegment("source", msg.Source) +} - // Only stamp CreatedAt and MaxRetries on initial enqueue (zero-valued). - // Re-enqueue paths (consumer retry-failed, not-yet-ready, prune) pass - // messages that already carry the original values; overwriting them would - // permanently lose the original failure timestamp and retry budget. +// stampInitialEnqueue sets CreatedAt, MaxRetries, and NextRetryAt on messages +// that are being enqueued for the first time (CreatedAt is zero). +// Re-enqueue paths (consumer retry-failed, not-yet-ready, prune) pass messages +// that already carry the original values; overwriting them would permanently +// lose the original failure timestamp and retry budget. +func (h *Handler) stampInitialEnqueue(msg *FailedMessage) { initialEnqueue := msg.CreatedAt.IsZero() if initialEnqueue { msg.CreatedAt = time.Now().UTC() @@ -176,6 +218,20 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { msg.MaxRetries = h.maxRetries } + // Recalculate NextRetryAt only on initial enqueue. On re-enqueue the + // consumer has already incremented RetryCount and the caller is + // responsible for timing; we preserve their NextRetryAt or let the + // backoff be recalculated by the consumer path that sets RetryCount. 
+ if initialEnqueue && msg.RetryCount < msg.MaxRetries { + msg.NextRetryAt = msg.CreatedAt.Add(backoffDuration(msg.RetryCount)) + } +} + +// resolveAndValidateTenant determines the effective tenant ID for the message by +// reconciling the message's TenantID with the tenant from context. It validates +// that they match when both are present, and validates the tenant as a safe Redis +// key segment. Returns the effective tenant ID. +func (h *Handler) resolveAndValidateTenant(ctx context.Context, msg *FailedMessage) (string, error) { ctxTenant := tmcore.GetTenantIDContext(ctx) effectiveTenant := msg.TenantID @@ -185,7 +241,7 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { } if effectiveTenant != "" && ctxTenant != "" && effectiveTenant != ctxTenant { - return fmt.Errorf("dlq: enqueue: tenant mismatch between message (%s) and context (%s)", effectiveTenant, ctxTenant) + return "", fmt.Errorf("dlq: enqueue: tenant mismatch between message (%s) and context (%s)", effectiveTenant, ctxTenant) } // Validate the effective tenant before using it to construct a Redis key. @@ -194,43 +250,11 @@ func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error { // tenant-scoped messages into the global queue. if effectiveTenant != "" { if err := validateKeySegment("tenantID", effectiveTenant); err != nil { - return fmt.Errorf("dlq: enqueue: %w", err) + return "", fmt.Errorf("dlq: enqueue: %w", err) } } - // Recalculate NextRetryAt only on initial enqueue. On re-enqueue the - // consumer has already incremented RetryCount and the caller is - // responsible for timing; we preserve their NextRetryAt or let the - // backoff be recalculated by the consumer path that sets RetryCount. 
- if initialEnqueue && msg.RetryCount < msg.MaxRetries { - msg.NextRetryAt = msg.CreatedAt.Add(backoffDuration(msg.RetryCount)) - } - - data, err := json.Marshal(msg) - if err != nil { - libOtel.HandleSpanError(span, "dlq marshal failed", err) - - return fmt.Errorf("dlq: enqueue: marshal: %w", err) - } - - key := h.tenantScopedKeyForTenant(effectiveTenant, msg.Source) - - rds, err := h.conn.GetClient(ctx) - if err != nil { - libOtel.HandleSpanError(span, "dlq redis client unavailable", err) - h.logEnqueueFallback(ctx, key, msg, err) - - return fmt.Errorf("dlq: enqueue: redis client: %w", err) - } - - if pushErr := rds.RPush(ctx, key, data).Err(); pushErr != nil { - libOtel.HandleSpanError(span, "dlq rpush failed", pushErr) - h.logEnqueueFallback(ctx, key, msg, pushErr) - - return fmt.Errorf("dlq: enqueue: rpush: %w", pushErr) - } - - return nil + return effectiveTenant, nil } // logEnqueueFallback logs message metadata when Redis is unreachable. The diff --git a/commons/net/http/idempotency/idempotency.go b/commons/net/http/idempotency/idempotency.go index 6a2ff93c..45512e94 100644 --- a/commons/net/http/idempotency/idempotency.go +++ b/commons/net/http/idempotency/idempotency.go @@ -3,6 +3,7 @@ package idempotency import ( "context" "encoding/json" + "errors" "fmt" "net/http" "time" @@ -225,7 +226,7 @@ func (m *Middleware) handleDuplicate( ) error { // Read the current key value to distinguish in-flight from completed. keyValue, keyErr := client.Get(ctx, key).Result() - if keyErr != nil && keyErr != redis.Nil { + if keyErr != nil && !errors.Is(keyErr, redis.Nil) { // Unexpected Redis error (timeout, connection failure) — fail open. 
m.logger.Log(ctx, log.LevelWarn, "idempotency: failed to read key state, failing open", @@ -239,7 +240,7 @@ func (m *Middleware) handleDuplicate( cached, cacheErr := client.Get(ctx, responseKey).Result() switch { - case cacheErr != nil && cacheErr != redis.Nil: + case cacheErr != nil && !errors.Is(cacheErr, redis.Nil): // Unexpected Redis error reading cached response — fail open. m.logger.Log(ctx, log.LevelWarn, "idempotency: failed to read cached response, failing open", @@ -303,6 +304,7 @@ func (m *Middleware) saveResult( if len(body) <= m.maxBodyCache { // Capture response headers for faithful replay. headers := make(map[string][]string) + c.Response().Header.VisitAll(func(key, value []byte) { name := string(key) // Skip headers managed by the middleware itself and @@ -312,6 +314,7 @@ func (m *Middleware) saveResult( chttp.IdempotencyReplayed: return } + headers[name] = append(headers[name], string(value)) }) From f63d6e4af77e6ad2743ac9005bbe1e59baaed501 Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 12:14:31 -0300 Subject: [PATCH 3/9] fix: replace deprecated Header.VisitAll with Header.All iterator staticcheck SA1019: fasthttp Header.VisitAll is deprecated in favor of the range-compatible All() iterator. --- commons/net/http/idempotency/idempotency.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/commons/net/http/idempotency/idempotency.go b/commons/net/http/idempotency/idempotency.go index 45512e94..07de7597 100644 --- a/commons/net/http/idempotency/idempotency.go +++ b/commons/net/http/idempotency/idempotency.go @@ -305,18 +305,18 @@ func (m *Middleware) saveResult( // Capture response headers for faithful replay. headers := make(map[string][]string) - c.Response().Header.VisitAll(func(key, value []byte) { + for key, value := range c.Response().Header.All() { name := string(key) // Skip headers managed by the middleware itself and // transfer-encoding / content-length which Fiber sets on send. 
switch name { case "Content-Type", "Content-Length", "Transfer-Encoding", chttp.IdempotencyReplayed: - return + continue } headers[name] = append(headers[name], string(value)) - }) + } resp := cachedResponse{ StatusCode: c.Response().StatusCode(), From 90b2e68875900d1bbac202966549668c047d259d Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 12:34:02 -0300 Subject: [PATCH 4/9] fix: address CodeRabbit round-2 findings on PR #421 - idempotency: use Header.Add for multi-value replay, log unmarshal failures, rename shadowed loop variable - catalog: introduce ValidationResult to separate env var conflicts from Mismatch, mark duplicate env vars as ambiguous in index - webhook: strip URL fragments in sanitizeURL, fix godoc references to VerifySignature/VerifySignatureWithFreshness, pin TestComputeHMACv1 to independently computed expected value --- commons/net/http/idempotency/idempotency.go | 18 ++++-- commons/systemplane/catalog/validate.go | 61 +++++++++++++++---- .../catalog/validate_options_test.go | 20 +++--- commons/webhook/deliverer.go | 10 +-- commons/webhook/deliverer_test.go | 19 +++++- 5 files changed, 95 insertions(+), 33 deletions(-) diff --git a/commons/net/http/idempotency/idempotency.go b/commons/net/http/idempotency/idempotency.go index 07de7597..96a6f3f6 100644 --- a/commons/net/http/idempotency/idempotency.go +++ b/commons/net/http/idempotency/idempotency.go @@ -250,12 +250,22 @@ func (m *Middleware) handleDuplicate( return c.Next() case cacheErr == nil && cached != "": var resp cachedResponse - if unmarshalErr := json.Unmarshal([]byte(cached), &resp); unmarshalErr == nil { + if unmarshalErr := json.Unmarshal([]byte(cached), &resp); unmarshalErr != nil { + // Cache entry is corrupt or written by an incompatible version. + // Log a warning so operators can investigate, then fall through + // to the generic "already processed" response (fail-open). 
+ m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to unmarshal cached response, falling through to generic reply", + log.String("key", responseKey), log.Err(unmarshalErr), + ) + } else { // Replay persisted headers first so the caller sees // Location, ETag, Set-Cookie, etc. exactly as sent originally. + // Use Header.Add (not c.Set) so multi-value headers such as + // Set-Cookie are appended rather than silently overwritten. for name, values := range resp.Headers { for _, v := range values { - c.Set(name, v) + c.Response().Header.Add(name, v) } } @@ -305,8 +315,8 @@ func (m *Middleware) saveResult( // Capture response headers for faithful replay. headers := make(map[string][]string) - for key, value := range c.Response().Header.All() { - name := string(key) + for hdrKey, value := range c.Response().Header.All() { + name := string(hdrKey) // Skip headers managed by the middleware itself and // transfer-encoding / content-length which Fiber sets on send. switch name { diff --git a/commons/systemplane/catalog/validate.go b/commons/systemplane/catalog/validate.go index ac3984b7..8aca53dc 100644 --- a/commons/systemplane/catalog/validate.go +++ b/commons/systemplane/catalog/validate.go @@ -32,6 +32,17 @@ func (m Mismatch) String() string { m.CatalogKey, m.Field, m.CatalogValue, m.ProductValue) } +// ValidationResult holds the outcome of ValidateKeyDefsWithOptions, separating +// product-vs-catalog mismatches from catalog-internal env var conflicts. +type ValidationResult struct { + // Mismatches contains divergences between product KeyDefs and catalog SharedKeys. + Mismatches []Mismatch + + // EnvVarConflicts contains catalog-internal duplicate env var mappings + // where two SharedKeys map the same environment variable. + EnvVarConflicts []EnvVarConflict +} + // ValidateOption configures the behavior of ValidateKeyDefs. 
type ValidateOption func(*validateConfig) @@ -108,28 +119,23 @@ func (vc *validateConfig) shouldSkip(catalogKey, field string) bool { // ValidateKeyDefsWithOptions which accepts ValidateOption values. // ValidateKeyDefs retains its original signature for backward compatibility. func ValidateKeyDefs(productDefs []domain.KeyDef, catalogKeys ...[]SharedKey) []Mismatch { - return ValidateKeyDefsWithOptions(productDefs, catalogKeys, nil) + result := ValidateKeyDefsWithOptions(productDefs, catalogKeys, nil) + return result.Mismatches } // ValidateKeyDefsWithOptions is the full-featured variant of ValidateKeyDefs // that accepts filtering options. Use WithIgnoreFields and WithKnownDeviation // to suppress expected mismatches in product catalog tests. -func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]SharedKey, opts []ValidateOption) []Mismatch { +// +// Returns a ValidationResult separating product-vs-catalog mismatches from +// catalog-internal env var conflicts. Use result.Mismatches for product drift +// and result.EnvVarConflicts for catalog-internal duplicate env var mappings. +func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]SharedKey, opts []ValidateOption) ValidationResult { keyIndex, envIndex, envConflicts := buildCatalogIndexes(catalogKeys...) 
vc := newValidateConfig(opts) var mismatches []Mismatch - for _, conflict := range envConflicts { - mismatches = append(mismatches, Mismatch{ - CatalogKey: conflict.ConflictKey, - ProductKey: conflict.ExistingKey, - Field: "EnvVar", - CatalogValue: conflict.ExistingKey, - ProductValue: conflict.ConflictKey, - }) - } - for _, pd := range productDefs { sk, found, matchedByEnv := resolveSharedKey(pd, keyIndex, envIndex) @@ -154,12 +160,24 @@ func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]Sha return mismatches[i].Field < mismatches[j].Field }) - return mismatches + sort.Slice(envConflicts, func(i, j int) bool { + if envConflicts[i].EnvVar != envConflicts[j].EnvVar { + return envConflicts[i].EnvVar < envConflicts[j].EnvVar + } + + return envConflicts[i].ExistingKey < envConflicts[j].ExistingKey + }) + + return ValidationResult{ + Mismatches: mismatches, + EnvVarConflicts: envConflicts, + } } func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[string]SharedKey, []EnvVarConflict) { keyIndex := make(map[string]SharedKey) envIndex := make(map[string]SharedKey) + ambiguousEnvVars := make(map[string]bool) var conflicts []EnvVarConflict @@ -167,6 +185,18 @@ func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[ for _, sk := range slice { keyIndex[sk.Key] = sk for _, envVar := range allowedEnvVars(sk) { + if ambiguousEnvVars[envVar] { + // Already known ambiguous from a prior conflict; record + // this additional claimant but do not re-add to envIndex. 
+ conflicts = append(conflicts, EnvVarConflict{ + EnvVar: envVar, + ExistingKey: "(ambiguous)", + ConflictKey: sk.Key, + }) + + continue + } + if existing, duplicate := envIndex[envVar]; duplicate { conflicts = append(conflicts, EnvVarConflict{ EnvVar: envVar, @@ -174,6 +204,11 @@ func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[ ConflictKey: sk.Key, }) + // Mark as ambiguous and remove from envIndex so + // resolveSharedKey will not match this env var. + ambiguousEnvVars[envVar] = true + delete(envIndex, envVar) + continue } diff --git a/commons/systemplane/catalog/validate_options_test.go b/commons/systemplane/catalog/validate_options_test.go index 561ea412..927766e5 100644 --- a/commons/systemplane/catalog/validate_options_test.go +++ b/commons/systemplane/catalog/validate_options_test.go @@ -30,9 +30,9 @@ func TestValidateKeyDefsWithOptions_WithIgnoreFields(t *testing.T) { require.NotEmpty(t, envVarMismatches, "expected EnvVar mismatch without options") // With WithIgnoreFields("EnvVar"): should suppress the EnvVar mismatch. - filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, []ValidateOption{WithIgnoreFields("EnvVar")}) - envVarMismatches = filterByField(filtered, "EnvVar") + envVarMismatches = filterByField(result.Mismatches, "EnvVar") assert.Empty(t, envVarMismatches, "EnvVar mismatches should be suppressed by WithIgnoreFields") } @@ -54,9 +54,9 @@ func TestValidateKeyDefsWithOptions_WithKnownDeviation(t *testing.T) { require.NotEmpty(t, componentMismatches, "expected Component mismatch without options") // With WithKnownDeviation: should suppress only that key+field. 
- filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, []ValidateOption{WithKnownDeviation("app.log_level", "Component")}) - componentMismatches = filterByField(filtered, "Component") + componentMismatches = filterByField(result.Mismatches, "Component") assert.Empty(t, componentMismatches, "Component mismatch should be suppressed by WithKnownDeviation") } @@ -76,12 +76,12 @@ func TestValidateKeyDefsWithOptions_KnownDeviationDoesNotSuppressOtherKeys(t *te }, } - filtered := ValidateKeyDefsWithOptions(productDefs, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys(), CORSKeys()}, []ValidateOption{WithKnownDeviation("app.log_level", "Component")}) // app.log_level Component should be suppressed. - for _, mm := range filtered { + for _, mm := range result.Mismatches { if mm.CatalogKey == "app.log_level" && mm.Field == "Component" { t.Error("app.log_level Component deviation should have been suppressed") } @@ -89,7 +89,7 @@ func TestValidateKeyDefsWithOptions_KnownDeviationDoesNotSuppressOtherKeys(t *te // cors.allowed_origins Component should NOT be suppressed. corsComponentFound := false - for _, mm := range filtered { + for _, mm := range result.Mismatches { if mm.CatalogKey == "cors.allowed_origins" && mm.Field == "Component" { corsComponentFound = true } @@ -111,7 +111,7 @@ func TestValidateKeyDefsWithOptions_NilOptions(t *testing.T) { // Nil options should work the same as no options. 
result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, nil) - assert.Empty(t, result, "matching key with nil options should produce no mismatches") + assert.Empty(t, result.Mismatches, "matching key with nil options should produce no mismatches") } func TestValidateKeyDefsWithOptions_CombinedOptions(t *testing.T) { @@ -126,13 +126,13 @@ func TestValidateKeyDefsWithOptions_CombinedOptions(t *testing.T) { }, } - filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, + result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, []ValidateOption{ WithIgnoreFields("EnvVar"), WithKnownDeviation("app.log_level", "Component"), }) - assert.Empty(t, filtered, "all mismatches should be suppressed by combined options") + assert.Empty(t, result.Mismatches, "all mismatches should be suppressed by combined options") } func filterByField(mismatches []Mismatch, field string) []Mismatch { diff --git a/commons/webhook/deliverer.go b/commons/webhook/deliverer.go index 9747dab9..2b9a9055 100644 --- a/commons/webhook/deliverer.go +++ b/commons/webhook/deliverer.go @@ -146,12 +146,13 @@ func WithSecretDecryptor(fn SecretDecryptor) Option { // WithSignatureVersion selects the HMAC signing format for X-Webhook-Signature. // The default is SignatureV0 (payload-only) for backward compatibility. -// Use SignatureV1 to include the event timestamp in the HMAC input and produce -// a versioned signature string that enables replay protection. +// SignatureV1 produces a versioned "v1,sha256=..." signature string that binds +// the event timestamp into the HMAC input, enabling replay protection. +// Receivers can enforce freshness using [VerifySignatureWithFreshness], or +// perform basic signature verification using [VerifySignature]. // // Migration path: switch to SignatureV1 only after all consumers have been -// updated to verify the "v1,sha256=..." format. 
See VerifySignatureV1 for
-// the receiver-side verification logic.
+// updated to verify the "v1,sha256=..." format.
 func WithSignatureVersion(v SignatureVersion) Option {
 	return func(d *Deliverer) {
 		d.sigVersion = v
@@ -752,6 +753,7 @@ func sanitizeURL(rawURL string) string {
 
 	u.RawQuery = ""
 	u.User = nil
+	u.Fragment = ""
 
 	return u.String()
 }
diff --git a/commons/webhook/deliverer_test.go b/commons/webhook/deliverer_test.go
index 43ec3fd8..bcca3729 100644
--- a/commons/webhook/deliverer_test.go
+++ b/commons/webhook/deliverer_test.go
@@ -934,6 +934,16 @@ func TestSanitizeURL(t *testing.T) {
 			raw:  "https://example.com/webhook",
 			want: "https://example.com/webhook",
 		},
+		{
+			name: "strips fragment",
+			raw:  "https://example.com/webhook#access_token=secret",
+			want: "https://example.com/webhook",
+		},
+		{
+			name: "strips fragment with query and userinfo",
+			raw:  "https://user:pass@example.com/hook?key=val#frag",
+			want: "https://example.com/hook",
+		},
 		{
 			name: "invalid URL returns placeholder",
 			raw:  "://bad\x7f",
@@ -962,9 +972,14 @@ func TestComputeHMACv1(t *testing.T) {
 	secret := "test-secret"
 	timestamp := int64(1700000000)
 
+	// Independently computed expected value.
+	// Wire format: HMAC-SHA256("v1:1700000000.{\"id\":\"123\"}", "test-secret")
+	// prefixed with "v1,sha256=".
+	const expected = "v1,sha256=49d0851d6baa34655d12dfba153c748db08b84830d4cb01780d833956c59844c"
+
 	sig := computeHMACv1(payload, timestamp, secret)
-	assert.True(t, strings.HasPrefix(sig, "v1,sha256="),
-		"v1 signature must start with 'v1,sha256='")
+	assert.Equal(t, expected, sig,
+		"v1 signature must match independently computed HMAC-SHA256 of 'v1:<timestamp>.<payload>'")
 
 	// Verify it is deterministic.
sig2 := computeHMACv1(payload, timestamp, secret) From c6922d0ef00b9547e6b7053367809dbc00ebba60 Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 13:07:14 -0300 Subject: [PATCH 5/9] feat(security/ssrf): add canonical SSRF validation package Introduce commons/security/ssrf as the single source of truth for SSRF protection across all Lerian services. Consolidates two internal, duplicated implementations into one exported package. New exported API: - IsBlockedIP(net.IP) / IsBlockedAddr(netip.Addr): IP-level blocking with canonical CIDR blocklist (8 ranges + stdlib predicates) - IsBlockedHostname(hostname): hostname-level blocking for localhost, cloud metadata endpoints, .local/.internal/.cluster.local suffixes - BlockedPrefixes(): returns copy of CIDR blocklist for auditing - ValidateURL(ctx, url, opts...): scheme + hostname + IP validation without DNS resolution - ResolveAndValidate(ctx, url, opts...): DNS-pinned validation with TOCTOU elimination, returns ResolveResult{PinnedURL, Authority, SNIHostname} - Functional options: WithHTTPSOnly, WithAllowPrivateNetwork, WithLookupFunc, WithAllowHostname - Sentinel errors: ErrBlocked, ErrInvalidURL, ErrDNSFailed Refactored consumers: - commons/webhook/ssrf.go: resolveAndValidateIP delegates to ssrf.ResolveAndValidate, removed duplicated isPrivateIP/CIDR blocklist - commons/net/http/proxy_validation.go: isUnsafeIP delegates to ssrf.IsBlockedIP, removed duplicated blockedProxyPrefixes Canonicalized on netip.Prefix (modern Go) with net.IP bridge for legacy callers. All tests hermetic via WithLookupFunc injection. 
--- commons/net/http/proxy_validation.go | 38 +- commons/security/ssrf/doc.go | 52 ++ commons/security/ssrf/hostname.go | 90 +++ commons/security/ssrf/options.go | 84 +++ commons/security/ssrf/ssrf.go | 104 ++++ commons/security/ssrf/ssrf_test.go | 890 +++++++++++++++++++++++++++ commons/security/ssrf/validate.go | 200 ++++++ commons/webhook/ssrf.go | 186 +----- commons/webhook/ssrf_test.go | 165 ----- 9 files changed, 1456 insertions(+), 353 deletions(-) create mode 100644 commons/security/ssrf/doc.go create mode 100644 commons/security/ssrf/hostname.go create mode 100644 commons/security/ssrf/options.go create mode 100644 commons/security/ssrf/ssrf.go create mode 100644 commons/security/ssrf/ssrf_test.go create mode 100644 commons/security/ssrf/validate.go diff --git a/commons/net/http/proxy_validation.go b/commons/net/http/proxy_validation.go index 2cc5f8ed..55ceab41 100644 --- a/commons/net/http/proxy_validation.go +++ b/commons/net/http/proxy_validation.go @@ -2,21 +2,11 @@ package http import ( "net" - "net/netip" "net/url" "strings" -) -var blockedProxyPrefixes = []netip.Prefix{ - netip.MustParsePrefix("0.0.0.0/8"), - netip.MustParsePrefix("100.64.0.0/10"), - netip.MustParsePrefix("192.0.0.0/24"), - netip.MustParsePrefix("192.0.2.0/24"), - netip.MustParsePrefix("198.18.0.0/15"), - netip.MustParsePrefix("198.51.100.0/24"), - netip.MustParsePrefix("203.0.113.0/24"), - netip.MustParsePrefix("240.0.0.0/4"), -} + libSSRF "github.com/LerianStudio/lib-commons/v4/commons/security/ssrf" +) // validateProxyTarget checks a parsed URL against the reverse proxy policy. func validateProxyTarget(targetURL *url.URL, policy ReverseProxyPolicy) error { @@ -79,27 +69,7 @@ func isAllowedHost(host string, allowedHosts []string) bool { } // isUnsafeIP reports whether ip is a loopback, private, or otherwise non-routable address. +// It delegates to the canonical SSRF package for the actual blocked-range check. 
func isUnsafeIP(ip net.IP) bool { - if ip == nil { - return true - } - - if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() { - return true - } - - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return true - } - - addr = addr.Unmap() - - for _, prefix := range blockedProxyPrefixes { - if prefix.Contains(addr) { - return true - } - } - - return false + return libSSRF.IsBlockedIP(ip) } diff --git a/commons/security/ssrf/doc.go b/commons/security/ssrf/doc.go new file mode 100644 index 00000000..463c2347 --- /dev/null +++ b/commons/security/ssrf/doc.go @@ -0,0 +1,52 @@ +// Package ssrf provides server-side request forgery (SSRF) protection for HTTP +// clients that connect to user-controlled URLs. +// +// # Design Decisions +// +// - Fail-closed: every validation error rejects the request. DNS lookup +// failures, unparseable IPs, and empty resolution results all return an +// error rather than silently allowing the request through. +// +// - DNS pinning: [ResolveAndValidate] performs a single DNS lookup, validates +// all resolved IPs, and rewrites the URL to the first safe IP. This +// eliminates the TOCTOU (time-of-check-to-time-of-use) window that exists +// when validation and connection happen in separate steps; a DNS rebinding +// attack cannot change the record between those two operations. +// +// - Single source of truth: the [BlockedPrefixes] list is the canonical CIDR +// blocklist for the entire lib-commons module. Both the webhook deliverer +// and the reverse-proxy helper delegate to this package instead of +// maintaining their own blocklists. +// +// - Modern types: [netip.Prefix] and [netip.Addr] are the canonical types. +// A legacy [net.IP] entry point ([IsBlockedIP]) is provided for callers +// that have not yet migrated, but it delegates to [IsBlockedAddr] after +// conversion. 
+// +// # Usage +// +// Quick IP check (no DNS): +// +// addr, _ := netip.ParseAddr("10.0.0.1") +// if ssrf.IsBlockedAddr(addr) { +// // reject +// } +// +// Full URL validation with DNS pinning: +// +// result, err := ssrf.ResolveAndValidate(ctx, "https://example.com/hook") +// if err != nil { +// // reject — err wraps ssrf.ErrBlocked, ssrf.ErrInvalidURL, or ssrf.ErrDNSFailed +// } +// // Use result.PinnedURL for the actual HTTP request. +// // Set the Host header to result.Authority. +// // Set TLS ServerName to result.SNIHostname. +// +// Custom DNS resolver for tests: +// +// result, err := ssrf.ResolveAndValidate(ctx, rawURL, +// ssrf.WithLookupFunc(func(_ context.Context, _ string) ([]string, error) { +// return []string{"93.184.216.34"}, nil +// }), +// ) +package ssrf diff --git a/commons/security/ssrf/hostname.go b/commons/security/ssrf/hostname.go new file mode 100644 index 00000000..86c4468a --- /dev/null +++ b/commons/security/ssrf/hostname.go @@ -0,0 +1,90 @@ +package ssrf + +import "strings" + +// blockedHostnames contains exact hostnames that must be rejected regardless of +// IP resolution. Case-insensitive comparison is used. +// +//nolint:gochecknoglobals // package-level hostname blocklist is intentional for SSRF protection +var blockedHostnames = map[string]bool{ + "localhost": true, + "metadata.google.internal": true, + "metadata.gcp.internal": true, + // AWS metadata IP as a hostname (also caught by IP check — defense-in-depth). + "169.254.169.254": true, +} + +// blockedSuffixes contains hostname suffixes that indicate internal or +// non-routable DNS names. Any hostname ending with one of these suffixes is +// rejected. +// +// Note on ".internal": this blocks cloud metadata endpoints and internal DNS +// names. In corporate environments that use ".internal" as a legitimate TLD, +// use [WithAllowHostname] to exempt specific hosts. 
+// +//nolint:gochecknoglobals // package-level suffix blocklist is intentional for SSRF protection +var blockedSuffixes = []string{ + ".local", // mDNS / Bonjour (RFC 6762) + ".internal", // cloud metadata, internal DNS + ".cluster.local", // Kubernetes internal DNS +} + +// IsBlockedHostname reports whether hostname matches known dangerous patterns. +// +// Checks performed (case-insensitive): +// - Empty hostname. +// - Exact match against known blocked hostnames (localhost, cloud metadata +// endpoints, AWS metadata IP). +// - Suffix match against dangerous suffixes (.local, .internal, +// .cluster.local). +// +// Note: the ".internal" suffix blocks cloud metadata and internal DNS names but +// may affect legitimate ".internal" domains in corporate environments. Use +// [WithAllowHostname] to exempt specific hostnames when calling +// [ValidateURL] or [ResolveAndValidate]. +func IsBlockedHostname(hostname string) bool { + if hostname == "" { + return true + } + + lower := strings.ToLower(hostname) + + if blockedHostnames[lower] { + return true + } + + for _, suffix := range blockedSuffixes { + if strings.HasSuffix(lower, suffix) { + return true + } + } + + return false +} + +// isBlockedHostnameWithConfig performs the same check as [IsBlockedHostname] +// but respects the allowedHostnames override from functional options. +func isBlockedHostnameWithConfig(hostname string, cfg *config) bool { + if hostname == "" { + return true + } + + lower := strings.ToLower(hostname) + + // Check allow-list override first. 
+ if cfg != nil && cfg.allowedHostnames[lower] { + return false + } + + if blockedHostnames[lower] { + return true + } + + for _, suffix := range blockedSuffixes { + if strings.HasSuffix(lower, suffix) { + return true + } + } + + return false +} diff --git a/commons/security/ssrf/options.go b/commons/security/ssrf/options.go new file mode 100644 index 00000000..d509d4c8 --- /dev/null +++ b/commons/security/ssrf/options.go @@ -0,0 +1,84 @@ +package ssrf + +import ( + "context" + "strings" +) + +// LookupFunc is the signature for a DNS resolver function. It mirrors +// [net.Resolver.LookupHost] and is used by [WithLookupFunc] to inject a custom +// resolver for testing or special environments. +type LookupFunc func(ctx context.Context, host string) ([]string, error) + +// Option configures the behaviour of [ValidateURL] and [ResolveAndValidate]. +type Option func(*config) + +// config holds the resolved configuration built from functional [Option] values. +type config struct { + // httpsOnly rejects URLs whose scheme is not "https". + httpsOnly bool + + // allowPrivate bypasses IP blocking entirely. Intended for local + // development and testing only — never enable in production. + allowPrivate bool + + // lookupFunc overrides the default DNS resolver. When nil, the default + // net.DefaultResolver.LookupHost is used. + lookupFunc LookupFunc + + // allowedHostnames exempts specific hostnames from [IsBlockedHostname] + // checks. Keys are stored lower-cased. + allowedHostnames map[string]bool +} + +// buildConfig applies all options to a default config. +func buildConfig(opts []Option) *config { + cfg := &config{} + + for _, opt := range opts { + opt(cfg) + } + + return cfg +} + +// WithHTTPSOnly rejects any URL whose scheme is not "https". Use this when the +// target must always be contacted over TLS. 
+func WithHTTPSOnly() Option { + return func(c *config) { + c.httpsOnly = true + } +} + +// WithAllowPrivateNetwork bypasses IP blocking entirely, allowing connections +// to loopback, private, and link-local addresses. This is intended exclusively +// for local development and testing — never enable it in production. +func WithAllowPrivateNetwork() Option { + return func(c *config) { + c.allowPrivate = true + } +} + +// WithLookupFunc sets a custom DNS resolver. This is primarily useful in tests +// to avoid real DNS lookups and to exercise specific resolution scenarios (e.g. +// all IPs blocked, mixed safe/blocked IPs). +func WithLookupFunc(fn LookupFunc) Option { + return func(c *config) { + c.lookupFunc = fn + } +} + +// WithAllowHostname exempts a specific hostname from [IsBlockedHostname] +// checks. The comparison is case-insensitive. This is useful in corporate +// environments where a legitimate service runs on a ".internal" domain. +// +// Multiple calls accumulate — each call adds one hostname to the allow-list. +func WithAllowHostname(hostname string) Option { + return func(c *config) { + if c.allowedHostnames == nil { + c.allowedHostnames = make(map[string]bool) + } + + c.allowedHostnames[strings.ToLower(hostname)] = true + } +} diff --git a/commons/security/ssrf/ssrf.go b/commons/security/ssrf/ssrf.go new file mode 100644 index 00000000..8751feab --- /dev/null +++ b/commons/security/ssrf/ssrf.go @@ -0,0 +1,104 @@ +package ssrf + +import ( + "errors" + "net" + "net/netip" +) + +// Sentinel errors returned by validation functions. Callers should use +// [errors.Is] to check the category because concrete errors wrap these +// sentinels with additional context via [fmt.Errorf] and %w. +var ( + // ErrBlocked is returned when a URL or IP is rejected by SSRF protection. + ErrBlocked = errors.New("ssrf: blocked") + + // ErrInvalidURL is returned when a URL cannot be parsed or is structurally + // invalid (empty hostname, missing scheme, etc.). 
+ ErrInvalidURL = errors.New("ssrf: invalid URL") + + // ErrDNSFailed is returned when DNS resolution fails for a hostname. + ErrDNSFailed = errors.New("ssrf: DNS resolution failed") +) + +// blockedPrefixes is the canonical CIDR blocklist. It covers RFC-defined +// special-purpose ranges that are not caught by the standard library predicates +// (IsLoopback, IsPrivate, IsLinkLocalUnicast, etc.). +// +// Each entry is intentionally a netip.Prefix literal so that typos are caught +// at init time rather than silently ignored at request time. +// +//nolint:gochecknoglobals // package-level CIDR blocklist is intentional for SSRF protection +var blockedPrefixes = []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/8"), // "this network" (RFC 1122 S3.2.1.3) + netip.MustParsePrefix("100.64.0.0/10"), // CGNAT (RFC 6598) + netip.MustParsePrefix("192.0.0.0/24"), // IETF protocol assignments (RFC 6890) + netip.MustParsePrefix("192.0.2.0/24"), // TEST-NET-1 documentation (RFC 5737) + netip.MustParsePrefix("198.18.0.0/15"), // benchmarking (RFC 2544) + netip.MustParsePrefix("198.51.100.0/24"), // TEST-NET-2 documentation (RFC 5737) + netip.MustParsePrefix("203.0.113.0/24"), // TEST-NET-3 documentation (RFC 5737) + netip.MustParsePrefix("240.0.0.0/4"), // reserved / future use (RFC 1112) +} + +// BlockedPrefixes returns a copy of the canonical CIDR blocklist. The returned +// slice is safe to modify without affecting the package state. +func BlockedPrefixes() []netip.Prefix { + out := make([]netip.Prefix, len(blockedPrefixes)) + copy(out, blockedPrefixes) + + return out +} + +// IsBlockedAddr reports whether addr falls in a private, loopback, link-local, +// multicast, unspecified, or other reserved range that must not be contacted by +// outbound HTTP requests. +// +// Check order: +// 1. Unmap IPv4-mapped IPv6 addresses (e.g. ::ffff:127.0.0.1 becomes 127.0.0.1). +// 2. 
Standard library predicates: IsLoopback, IsPrivate, IsLinkLocalUnicast, +// IsLinkLocalMulticast, IsMulticast, IsUnspecified. +// 3. Custom blocklist ([BlockedPrefixes]). +func IsBlockedAddr(addr netip.Addr) bool { + // Step 1: unmap IPv4-mapped IPv6 so that ::ffff:10.0.0.1 is treated + // identically to 10.0.0.1. + addr = addr.Unmap() + + // Step 2: standard library predicates. + if addr.IsLoopback() || + addr.IsPrivate() || + addr.IsLinkLocalUnicast() || + addr.IsLinkLocalMulticast() || + addr.IsMulticast() || + addr.IsUnspecified() { + return true + } + + // Step 3: custom CIDR blocklist for ranges not covered by the predicates + // above (CGNAT, TEST-NETs, benchmarking, reserved, etc.). + for _, prefix := range blockedPrefixes { + if prefix.Contains(addr) { + return true + } + } + + return false +} + +// IsBlockedIP reports whether ip falls in a blocked range. This is a +// convenience wrapper for callers that still use the legacy [net.IP] type; it +// converts to [netip.Addr] and delegates to [IsBlockedAddr]. +// +// A nil ip is considered blocked (fail-closed). +func IsBlockedIP(ip net.IP) bool { + if ip == nil { + return true + } + + addr, ok := netip.AddrFromSlice(ip) + if !ok { + // Unparseable IP — fail-closed. 
+ return true + } + + return IsBlockedAddr(addr) +} diff --git a/commons/security/ssrf/ssrf_test.go b/commons/security/ssrf/ssrf_test.go new file mode 100644 index 00000000..127bc628 --- /dev/null +++ b/commons/security/ssrf/ssrf_test.go @@ -0,0 +1,890 @@ +//go:build unit + +package ssrf + +import ( + "context" + "errors" + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// BlockedPrefixes +// --------------------------------------------------------------------------- + +// expectedPrefixCount lives in the test file so the production code does not +// export or carry a constant that is only meaningful for test assertions. +// Update this value when adding new CIDR ranges to blockedPrefixes. +const expectedPrefixCount = 8 + +func TestBlockedPrefixes_ReturnsExpectedCount(t *testing.T) { + t.Parallel() + + prefixes := BlockedPrefixes() + assert.Len(t, prefixes, expectedPrefixCount, + "BlockedPrefixes() should return exactly %d prefixes", expectedPrefixCount) +} + +func TestBlockedPrefixes_ReturnsCopy(t *testing.T) { + t.Parallel() + + a := BlockedPrefixes() + b := BlockedPrefixes() + + // Mutate the first slice and verify the second is unaffected. 
+ a[0] = netip.MustParsePrefix("255.255.255.255/32") + assert.NotEqual(t, a[0], b[0], + "BlockedPrefixes() must return independent copies") +} + +func TestBlockedPrefixes_ContainsExpectedRanges(t *testing.T) { + t.Parallel() + + expected := []string{ + "0.0.0.0/8", + "100.64.0.0/10", + "192.0.0.0/24", + "192.0.2.0/24", + "198.18.0.0/15", + "198.51.100.0/24", + "203.0.113.0/24", + "240.0.0.0/4", + } + + prefixes := BlockedPrefixes() + strs := make([]string, len(prefixes)) + + for i, p := range prefixes { + strs[i] = p.String() + } + + for _, exp := range expected { + assert.Contains(t, strs, exp, + "BlockedPrefixes() should contain %s", exp) + } +} + +// --------------------------------------------------------------------------- +// IsBlockedAddr — netip.Addr interface +// --------------------------------------------------------------------------- + +func TestIsBlockedAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + addr string + blocked bool + }{ + // --- Loopback --- + {name: "IPv4 loopback 127.0.0.1", addr: "127.0.0.1", blocked: true}, + {name: "IPv4 loopback 127.0.0.2", addr: "127.0.0.2", blocked: true}, + {name: "IPv4 loopback 127.255.255.255", addr: "127.255.255.255", blocked: true}, + {name: "IPv6 loopback ::1", addr: "::1", blocked: true}, + + // --- RFC 1918 private ranges --- + {name: "10.0.0.1 (Class A)", addr: "10.0.0.1", blocked: true}, + {name: "10.255.255.255 (Class A top)", addr: "10.255.255.255", blocked: true}, + {name: "172.16.0.1 (Class B)", addr: "172.16.0.1", blocked: true}, + {name: "172.31.255.255 (Class B top)", addr: "172.31.255.255", blocked: true}, + {name: "192.168.0.1 (Class C)", addr: "192.168.0.1", blocked: true}, + {name: "192.168.255.255 (Class C top)", addr: "192.168.255.255", blocked: true}, + + // --- Link-local --- + {name: "IPv4 link-local 169.254.1.1", addr: "169.254.1.1", blocked: true}, + {name: "IPv4 link-local AWS metadata", addr: "169.254.169.254", blocked: true}, + {name: "IPv6 link-local 
unicast fe80::1", addr: "fe80::1", blocked: true}, + + // --- Link-local multicast --- + {name: "IPv4 multicast 224.0.0.1", addr: "224.0.0.1", blocked: true}, + {name: "IPv4 multicast 239.255.255.255", addr: "239.255.255.255", blocked: true}, + + // --- Unspecified --- + {name: "IPv4 unspecified 0.0.0.0", addr: "0.0.0.0", blocked: true}, + {name: "IPv6 unspecified ::", addr: "::", blocked: true}, + + // --- CGNAT (RFC 6598): 100.64.0.0/10 --- + {name: "CGNAT low 100.64.0.1", addr: "100.64.0.1", blocked: true}, + {name: "CGNAT mid 100.100.100.100", addr: "100.100.100.100", blocked: true}, + {name: "CGNAT high 100.127.255.254", addr: "100.127.255.254", blocked: true}, + + // --- "this network" (RFC 1122): 0.0.0.0/8 --- + {name: "0.0.0.1 this-network", addr: "0.0.0.1", blocked: true}, + {name: "0.255.255.255 this-network top", addr: "0.255.255.255", blocked: true}, + + // --- IETF protocol assignments (RFC 6890): 192.0.0.0/24 --- + {name: "192.0.0.1 IETF", addr: "192.0.0.1", blocked: true}, + {name: "192.0.0.254 IETF top", addr: "192.0.0.254", blocked: true}, + + // --- TEST-NET-1 (RFC 5737): 192.0.2.0/24 --- + {name: "192.0.2.1 TEST-NET-1", addr: "192.0.2.1", blocked: true}, + + // --- Benchmarking (RFC 2544): 198.18.0.0/15 --- + {name: "198.18.0.1 benchmarking", addr: "198.18.0.1", blocked: true}, + {name: "198.19.255.255 benchmarking top", addr: "198.19.255.255", blocked: true}, + + // --- TEST-NET-2 (RFC 5737): 198.51.100.0/24 --- + {name: "198.51.100.1 TEST-NET-2", addr: "198.51.100.1", blocked: true}, + {name: "198.51.100.10 TEST-NET-2", addr: "198.51.100.10", blocked: true}, + + // --- TEST-NET-3 (RFC 5737): 203.0.113.0/24 --- + {name: "203.0.113.1 TEST-NET-3", addr: "203.0.113.1", blocked: true}, + {name: "203.0.113.10 TEST-NET-3", addr: "203.0.113.10", blocked: true}, + + // --- Reserved (RFC 1112): 240.0.0.0/4 --- + {name: "240.0.0.1 reserved", addr: "240.0.0.1", blocked: true}, + {name: "255.255.255.254 reserved top", addr: "255.255.255.254", blocked: 
true}, + + // --- IPv4-mapped IPv6 (must unmap and check) --- + {name: "IPv4-mapped loopback ::ffff:127.0.0.1", addr: "::ffff:127.0.0.1", blocked: true}, + {name: "IPv4-mapped private ::ffff:10.0.0.1", addr: "::ffff:10.0.0.1", blocked: true}, + {name: "IPv4-mapped CGNAT ::ffff:100.64.0.1", addr: "::ffff:100.64.0.1", blocked: true}, + {name: "IPv4-mapped TEST-NET-2 ::ffff:198.51.100.10", addr: "::ffff:198.51.100.10", blocked: true}, + + // --- Public IPs — must NOT be blocked --- + {name: "Google DNS 8.8.8.8", addr: "8.8.8.8", blocked: false}, + {name: "Cloudflare 1.1.1.1", addr: "1.1.1.1", blocked: false}, + {name: "Public IPv6 2001:4860:4860::8888", addr: "2001:4860:4860::8888", blocked: false}, + {name: "93.184.216.34 (example.com)", addr: "93.184.216.34", blocked: false}, + + // --- Boundary cases — just outside blocked ranges --- + {name: "100.63.255.255 below CGNAT", addr: "100.63.255.255", blocked: false}, + {name: "100.128.0.1 above CGNAT", addr: "100.128.0.1", blocked: false}, + {name: "172.15.255.255 below Class B", addr: "172.15.255.255", blocked: false}, + {name: "172.32.0.1 above Class B", addr: "172.32.0.1", blocked: false}, + {name: "1.0.0.1 above this-network", addr: "1.0.0.1", blocked: false}, + {name: "198.17.255.255 below benchmarking", addr: "198.17.255.255", blocked: false}, + {name: "198.20.0.1 above benchmarking", addr: "198.20.0.1", blocked: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + addr := netip.MustParseAddr(tt.addr) + assert.Equal(t, tt.blocked, IsBlockedAddr(addr), + "IsBlockedAddr(%s) = %v, want %v", tt.addr, IsBlockedAddr(addr), tt.blocked) + }) + } +} + +// --------------------------------------------------------------------------- +// IsBlockedIP — legacy net.IP interface +// --------------------------------------------------------------------------- + +func TestIsBlockedIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + blocked bool + }{ + 
{name: "IPv4 loopback", ip: "127.0.0.1", blocked: true}, + {name: "IPv6 loopback", ip: "::1", blocked: true}, + {name: "RFC 1918 10.x", ip: "10.0.0.1", blocked: true}, + {name: "RFC 1918 172.16.x", ip: "172.16.0.1", blocked: true}, + {name: "RFC 1918 192.168.x", ip: "192.168.0.1", blocked: true}, + {name: "Link-local", ip: "169.254.1.1", blocked: true}, + {name: "CGNAT", ip: "100.64.0.1", blocked: true}, + {name: "TEST-NET-1", ip: "192.0.2.1", blocked: true}, + {name: "Multicast", ip: "224.0.0.1", blocked: true}, + {name: "Google DNS", ip: "8.8.8.8", blocked: false}, + {name: "Cloudflare", ip: "1.1.1.1", blocked: false}, + {name: "Public IPv6", ip: "2001:4860:4860::8888", blocked: false}, + // IPv4-mapped IPv6 + {name: "IPv4-mapped loopback", ip: "::ffff:127.0.0.1", blocked: true}, + {name: "IPv4-mapped private", ip: "::ffff:10.0.0.1", blocked: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ip := net.ParseIP(tt.ip) + require.NotNil(t, ip, "failed to parse IP: %s", tt.ip) + assert.Equal(t, tt.blocked, IsBlockedIP(ip), + "IsBlockedIP(%s) = %v, want %v", tt.ip, IsBlockedIP(ip), tt.blocked) + }) + } +} + +func TestIsBlockedIP_NilIP(t *testing.T) { + t.Parallel() + + assert.True(t, IsBlockedIP(nil), "IsBlockedIP(nil) must return true (fail-closed)") +} + +// --------------------------------------------------------------------------- +// IsBlockedHostname +// --------------------------------------------------------------------------- + +func TestIsBlockedHostname(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + host string + blocked bool + }{ + // Exact blocked hostnames + {name: "localhost", host: "localhost", blocked: true}, + {name: "LOCALHOST uppercase", host: "LOCALHOST", blocked: true}, + {name: "Localhost mixed", host: "Localhost", blocked: true}, + {name: "metadata.google.internal", host: "metadata.google.internal", blocked: true}, + {name: "metadata.gcp.internal", host: 
"metadata.gcp.internal", blocked: true}, + {name: "AWS metadata IP", host: "169.254.169.254", blocked: true}, + + // Dangerous suffixes + {name: ".local suffix", host: "myhost.local", blocked: true}, + {name: ".internal suffix", host: "service.internal", blocked: true}, + {name: ".cluster.local suffix", host: "api.default.svc.cluster.local", blocked: true}, + {name: ".INTERNAL uppercase", host: "service.INTERNAL", blocked: true}, + + // Empty hostname + {name: "empty string", host: "", blocked: true}, + + // Legitimate hostnames — must NOT be blocked + {name: "example.com", host: "example.com", blocked: false}, + {name: "api.example.com", host: "api.example.com", blocked: false}, + {name: "hooks.stripe.com", host: "hooks.stripe.com", blocked: false}, + {name: "8.8.8.8 public", host: "8.8.8.8", blocked: false}, + {name: "2001:4860:4860::8888", host: "2001:4860:4860::8888", blocked: false}, + + // Tricky near-misses that must NOT be blocked + {name: "localhostx not blocked", host: "localhostx", blocked: false}, + {name: "mylocal not blocked", host: "mylocal", blocked: false}, + {name: "internal.example.com not blocked", host: "internal.example.com", blocked: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.blocked, IsBlockedHostname(tt.host), + "IsBlockedHostname(%q) = %v, want %v", tt.host, IsBlockedHostname(tt.host), tt.blocked) + }) + } +} + +// --------------------------------------------------------------------------- +// ValidateURL +// --------------------------------------------------------------------------- + +func TestValidateURL_ValidPublicURLs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "https example.com", url: "https://example.com/hook"}, + {name: "http example.com", url: "http://example.com/hook"}, + {name: "https with port", url: "https://example.com:8443/hook"}, + {name: "https with path and query", url: 
"https://example.com/a/b?c=d"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + assert.NoError(t, err) + }) + } +} + +func TestValidateURL_BlockedSchemes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "gopher", url: "gopher://evil.com"}, + {name: "file", url: "file:///etc/passwd"}, + {name: "ftp", url: "ftp://example.com/file"}, + {name: "javascript", url: "javascript:alert(1)"}, + {name: "data", url: "data:text/html,
<script>hi</script>
"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestValidateURL_BlockedHostnames(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "localhost", url: "http://localhost/hook"}, + {name: "metadata.google.internal", url: "https://metadata.google.internal/computeMetadata"}, + {name: ".local suffix", url: "http://printer.local/status"}, + {name: "k8s internal", url: "http://api.default.svc.cluster.local/health"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestValidateURL_BlockedIPLiteral(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "loopback", url: "http://127.0.0.1/hook"}, + {name: "private 10.x", url: "https://10.0.0.1/hook"}, + {name: "private 192.168.x", url: "https://192.168.1.1/hook"}, + {name: "CGNAT", url: "http://100.64.0.1/hook"}, + {name: "IPv6 loopback", url: "http://[::1]/hook"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestValidateURL_EmptyHostname(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestValidateURL_MalformedURL(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "://missing-scheme") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestValidateURL_HTTPSOnly(t *testing.T) { + t.Parallel() + + t.Run("https allowed", func(t *testing.T) { + t.Parallel() + + err := 
ValidateURL(context.Background(), "https://example.com/hook", WithHTTPSOnly()) + assert.NoError(t, err) + }) + + t.Run("http rejected", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://example.com/hook", WithHTTPSOnly()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + assert.Contains(t, err.Error(), "HTTPS only") + }) +} + +func TestValidateURL_AllowPrivateNetwork(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "loopback IP", url: "http://127.0.0.1/hook"}, + {name: "private IP", url: "http://10.0.0.1/hook"}, + {name: "CGNAT IP", url: "http://100.64.0.1/hook"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), tt.url, WithAllowPrivateNetwork()) + // AllowPrivateNetwork bypasses IP blocking but hostname blocking + // is still evaluated. Since these are IP literals (not "localhost"), + // they should pass. + assert.NoError(t, err) + }) + } +} + +func TestValidateURL_AllowHostname(t *testing.T) { + t.Parallel() + + t.Run("exempted hostname passes", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://service.internal/hook", + WithAllowHostname("service.internal")) + assert.NoError(t, err) + }) + + t.Run("non-exempted hostname still blocked", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://other.internal/hook", + WithAllowHostname("service.internal")) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + + t.Run("case insensitive", func(t *testing.T) { + t.Parallel() + + err := ValidateURL(context.Background(), "http://Service.Internal/hook", + WithAllowHostname("service.internal")) + assert.NoError(t, err) + }) +} + +// --------------------------------------------------------------------------- +// ResolveAndValidate +// --------------------------------------------------------------------------- + 
+// fakeLookup returns a LookupFunc that yields the given IPs. +func fakeLookup(ips ...string) LookupFunc { + return func(_ context.Context, _ string) ([]string, error) { + return ips, nil + } +} + +// failLookup returns a LookupFunc that always fails. +func failLookup() LookupFunc { + return func(_ context.Context, _ string) ([]string, error) { + return nil, errors.New("simulated DNS failure") + } +} + +func TestResolveAndValidate_PublicIP(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com:8443/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "https://93.184.216.34:8443/hook", result.PinnedURL) + assert.Equal(t, "example.com:8443", result.Authority) + assert.Equal(t, "example.com", result.SNIHostname) +} + +func TestResolveAndValidate_PublicIPNoPort(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "https://93.184.216.34/hook", result.PinnedURL) + assert.Equal(t, "example.com", result.Authority) + assert.Equal(t, "example.com", result.SNIHostname) +} + +func TestResolveAndValidate_IPv6Public(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("2001:4860:4860::8888")), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "https://[2001:4860:4860::8888]/hook", result.PinnedURL) + assert.Equal(t, "example.com", result.Authority) + assert.Equal(t, "example.com", result.SNIHostname) +} + +func TestResolveAndValidate_BlockedIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + }{ + {name: "loopback", ip: "127.0.0.1"}, + {name: "private 10.x", ip: "10.0.0.1"}, + {name: "private 
192.168.x", ip: "192.168.1.1"}, + {name: "CGNAT", ip: "100.64.0.1"}, + {name: "link-local", ip: "169.254.169.254"}, + {name: "TEST-NET-1", ip: "192.0.2.1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup(tt.ip)), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestResolveAndValidate_DNSFailure(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(failLookup()), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrDNSFailed) +} + +func TestResolveAndValidate_DNSEmptyResult(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup()), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrDNSFailed) +} + +func TestResolveAndValidate_BlockedScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "gopher", url: "gopher://evil.com"}, + {name: "file", url: "file:///etc/passwd"}, + {name: "ftp", url: "ftp://example.com/file"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), tt.url, + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestResolveAndValidate_BlockedHostname(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "localhost", url: "http://localhost/hook"}, + {name: "metadata.google.internal", url: "https://metadata.google.internal/computeMetadata"}, + {name: ".local suffix", url: "http://printer.local/status"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := 
ResolveAndValidate(context.Background(), tt.url, + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) + } +} + +func TestResolveAndValidate_EmptyHostname(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), "http://", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestResolveAndValidate_MalformedURL(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), "://no-scheme", + WithLookupFunc(fakeLookup("93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestResolveAndValidate_HTTPSOnly(t *testing.T) { + t.Parallel() + + t.Run("https passes", func(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + WithHTTPSOnly(), + ) + + require.NoError(t, err) + assert.Contains(t, result.PinnedURL, "https://") + }) + + t.Run("http rejected", func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "http://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + WithHTTPSOnly(), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) +} + +func TestResolveAndValidate_AllowPrivateNetwork(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "http://example.com/hook", + WithLookupFunc(fakeLookup("10.0.0.1")), + WithAllowPrivateNetwork(), + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "http://10.0.0.1/hook", result.PinnedURL) +} + +func TestResolveAndValidate_AllowHostname(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), + "http://service.internal/hook", + WithLookupFunc(fakeLookup("93.184.216.34")), + WithAllowHostname("service.internal"), + ) + + 
require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "http://93.184.216.34/hook", result.PinnedURL) + assert.Equal(t, "service.internal", result.SNIHostname) +} + +func TestResolveAndValidate_MultipleIPs_FirstSafe(t *testing.T) { + t.Parallel() + + // First IP is safe — should be selected for pinning. + result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("93.184.216.34", "1.1.1.1")), + ) + + require.NoError(t, err) + assert.Equal(t, "https://93.184.216.34/hook", result.PinnedURL) +} + +func TestResolveAndValidate_MultipleIPs_BlockedAmongSafe(t *testing.T) { + t.Parallel() + + // First IP is blocked — fail-closed: reject the entire request. + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("10.0.0.1", "93.184.216.34")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) +} + +func TestResolveAndValidate_AllIPsBlocked(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("127.0.0.1", "10.0.0.1", "192.168.1.1")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) +} + +func TestResolveAndValidate_UnparseableIPSkipped(t *testing.T) { + t.Parallel() + + // DNS returns an unparseable entry followed by a valid one. 
+ result, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("not-an-ip", "93.184.216.34")), + ) + + require.NoError(t, err) + assert.Equal(t, "https://93.184.216.34/hook", result.PinnedURL) +} + +func TestResolveAndValidate_AllUnparseable(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), + "https://example.com/hook", + WithLookupFunc(fakeLookup("not-an-ip", "also-not-an-ip")), + ) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +// --------------------------------------------------------------------------- +// Options composition +// --------------------------------------------------------------------------- + +func TestOptions_MultipleAllowHostname(t *testing.T) { + t.Parallel() + + opts := []Option{ + WithAllowHostname("a.internal"), + WithAllowHostname("b.internal"), + WithLookupFunc(fakeLookup("93.184.216.34")), + } + + t.Run("first allowed", func(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), "http://a.internal/hook", opts...) + require.NoError(t, err) + assert.Equal(t, "a.internal", result.SNIHostname) + }) + + t.Run("second allowed", func(t *testing.T) { + t.Parallel() + + result, err := ResolveAndValidate(context.Background(), "http://b.internal/hook", opts...) + require.NoError(t, err) + assert.Equal(t, "b.internal", result.SNIHostname) + }) + + t.Run("third still blocked", func(t *testing.T) { + t.Parallel() + + _, err := ResolveAndValidate(context.Background(), "http://c.internal/hook", opts...) 
+ require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) + }) +} + +// --------------------------------------------------------------------------- +// validateScheme (internal, tested via ValidateURL/ResolveAndValidate above, +// but explicit coverage for edge cases) +// --------------------------------------------------------------------------- + +func TestValidateScheme_CaseInsensitive(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + scheme string + }{ + {name: "HTTP uppercase", scheme: "HTTP"}, + {name: "HTTPS uppercase", scheme: "HTTPS"}, + {name: "Mixed Http", scheme: "Http"}, + {name: "Mixed hTtPs", scheme: "hTtPs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := validateScheme(tt.scheme, &config{}) + assert.NoError(t, err, "scheme %q must be allowed", tt.scheme) + }) + } +} + +func TestValidateScheme_EmptyScheme(t *testing.T) { + t.Parallel() + + err := validateScheme("", &config{}) + require.Error(t, err) + assert.ErrorIs(t, err, ErrBlocked) +} + +// --------------------------------------------------------------------------- +// isBlockedHostnameWithConfig +// --------------------------------------------------------------------------- + +func TestIsBlockedHostnameWithConfig_NilConfig(t *testing.T) { + t.Parallel() + + // nil config should behave like IsBlockedHostname. 
+ assert.True(t, isBlockedHostnameWithConfig("localhost", nil)) + assert.False(t, isBlockedHostnameWithConfig("example.com", nil)) +} + +func TestIsBlockedHostnameWithConfig_AllowOverride(t *testing.T) { + t.Parallel() + + cfg := &config{ + allowedHostnames: map[string]bool{ + "special.internal": true, + }, + } + + assert.False(t, isBlockedHostnameWithConfig("special.internal", cfg), + "allowed hostname should not be blocked") + assert.True(t, isBlockedHostnameWithConfig("other.internal", cfg), + "non-allowed .internal hostname should still be blocked") +} + +// --------------------------------------------------------------------------- +// Error sentinel identity +// --------------------------------------------------------------------------- + +func TestSentinelErrors_AreDistinct(t *testing.T) { + t.Parallel() + + assert.NotErrorIs(t, ErrBlocked, ErrInvalidURL) + assert.NotErrorIs(t, ErrBlocked, ErrDNSFailed) + assert.NotErrorIs(t, ErrInvalidURL, ErrDNSFailed) +} diff --git a/commons/security/ssrf/validate.go b/commons/security/ssrf/validate.go new file mode 100644 index 00000000..d1524665 --- /dev/null +++ b/commons/security/ssrf/validate.go @@ -0,0 +1,200 @@ +package ssrf + +import ( + "context" + "fmt" + "net" + "net/netip" + "net/url" + "strings" +) + +// ResolveResult holds the output of [ResolveAndValidate]. Callers should use +// [PinnedURL] for the actual HTTP request, set the Host header to [Authority], +// and configure TLS ServerName to [SNIHostname]. +type ResolveResult struct { + // PinnedURL is the original URL with the hostname replaced by the first + // safe resolved IP address. Using this URL for the HTTP request eliminates + // the TOCTOU window between DNS validation and connection. + PinnedURL string + + // Authority is the original host:port value (url.URL.Host) before DNS + // pinning. It preserves explicit non-default ports and should be used as + // the HTTP Host header value so the target server routes correctly. 
+ Authority string + + // SNIHostname is the bare hostname without port (url.URL.Hostname()). It + // should be used for TLS SNI / certificate verification. + SNIHostname string +} + +// ValidateURL checks a URL for SSRF safety without performing DNS resolution. +// It validates the scheme, hostname blocking, and any IP literal in the +// hostname. +// +// Use [ResolveAndValidate] when DNS pinning is needed (i.e. when you intend to +// actually connect to the URL). ValidateURL is suitable for pre-flight +// validation where DNS resolution is deferred or not desired. +// +// Errors wrap [ErrBlocked] or [ErrInvalidURL] for programmatic inspection via +// [errors.Is]. +func ValidateURL(ctx context.Context, rawURL string, opts ...Option) error { + _ = ctx // reserved for future use (e.g. tracing) + + cfg := buildConfig(opts) + + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("%w: %w", ErrInvalidURL, err) + } + + if err := validateScheme(u.Scheme, cfg); err != nil { + return err + } + + hostname := u.Hostname() + if hostname == "" { + return fmt.Errorf("%w: empty hostname", ErrInvalidURL) + } + + // Check hostname-level blocking (localhost, metadata endpoints, etc.). + if isBlockedHostnameWithConfig(hostname, cfg) { + return fmt.Errorf("%w: hostname %q is blocked", ErrBlocked, hostname) + } + + // If the hostname is an IP literal, validate it against the CIDR blocklist. + if !cfg.allowPrivate { + if addr, err := netip.ParseAddr(hostname); err == nil { + if IsBlockedAddr(addr) { + return fmt.Errorf("%w: IP %s is in a blocked range", ErrBlocked, hostname) + } + } + } + + return nil +} + +// ResolveAndValidate performs DNS resolution, validates all resolved IPs +// against the SSRF blocklist, and returns a pinned URL. This eliminates the +// TOCTOU window between "validate" and "connect" by combining both into a +// single DNS lookup. +// +// Flow: +// 1. Parse URL, validate scheme (http/https only, or https-only with [WithHTTPSOnly]). +// 2. 
Check hostname blocking ([IsBlockedHostname]). +// 3. Resolve DNS (single lookup via [net.DefaultResolver] or [WithLookupFunc]). +// 4. Validate ALL resolved IPs ([IsBlockedAddr]). +// 5. Pin URL to first safe IP, return [ResolveResult]. +// +// DNS lookup failures are fail-closed: if the hostname cannot be resolved the +// URL is rejected. When every resolved IP is blocked the URL is also rejected. +// +// Errors wrap [ErrBlocked], [ErrInvalidURL], or [ErrDNSFailed] for +// programmatic inspection via [errors.Is]. +func ResolveAndValidate(ctx context.Context, rawURL string, opts ...Option) (*ResolveResult, error) { + cfg := buildConfig(opts) + + u, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrInvalidURL, err) + } + + if err := validateScheme(u.Scheme, cfg); err != nil { + return nil, err + } + + hostname := u.Hostname() + if hostname == "" { + return nil, fmt.Errorf("%w: empty hostname", ErrInvalidURL) + } + + // Hostname-level blocking (localhost, metadata endpoints, dangerous suffixes). + if isBlockedHostnameWithConfig(hostname, cfg) { + return nil, fmt.Errorf("%w: hostname %q is blocked", ErrBlocked, hostname) + } + + // DNS resolution — use custom resolver if provided, otherwise default. + ips, dnsErr := lookupHost(ctx, hostname, cfg) + if dnsErr != nil { + return nil, fmt.Errorf("%w: lookup failed for %s: %w", ErrDNSFailed, hostname, dnsErr) + } + + if len(ips) == 0 { + return nil, fmt.Errorf("%w: no addresses returned for %s", ErrDNSFailed, hostname) + } + + // Validate every resolved IP and find the first safe one. + var firstSafeIP string + + for _, ipStr := range ips { + addr, parseErr := netip.ParseAddr(ipStr) + if parseErr != nil { + // Skip unparseable entries; if none survive we fail below. 
+ continue + } + + if !cfg.allowPrivate && IsBlockedAddr(addr) { + return nil, fmt.Errorf("%w: resolved IP %s is in a blocked range", ErrBlocked, ipStr) + } + + if firstSafeIP == "" { + firstSafeIP = ipStr + } + } + + if firstSafeIP == "" { + return nil, fmt.Errorf("%w: no valid IPs resolved for %s", ErrInvalidURL, hostname) + } + + // Preserve the original authority before rewriting the host. + authority := u.Host + + // Pin to first safe IP — prevents DNS rebinding across retries. + port := u.Port() + + switch { + case port != "": + u.Host = net.JoinHostPort(firstSafeIP, port) + case strings.Contains(firstSafeIP, ":"): + // Bare IPv6 literal must be bracket-wrapped for url.URL.Host. + u.Host = "[" + firstSafeIP + "]" + default: + u.Host = firstSafeIP + } + + return &ResolveResult{ + PinnedURL: u.String(), + Authority: authority, + SNIHostname: hostname, + }, nil +} + +// validateScheme checks that the URL scheme is allowed. By default only "http" +// and "https" are permitted. With [WithHTTPSOnly], only "https" is allowed. +func validateScheme(scheme string, cfg *config) error { + s := strings.ToLower(scheme) + + if cfg.httpsOnly { + if s != "https" { + return fmt.Errorf("%w: scheme %q not allowed (HTTPS only)", ErrBlocked, scheme) + } + + return nil + } + + if s != "http" && s != "https" { + return fmt.Errorf("%w: scheme %q not allowed", ErrBlocked, scheme) + } + + return nil +} + +// lookupHost resolves hostname via the configured lookup function or the +// default resolver. 
+func lookupHost(ctx context.Context, hostname string, cfg *config) ([]string, error) { + if cfg.lookupFunc != nil { + return cfg.lookupFunc(ctx, hostname) + } + + return net.DefaultResolver.LookupHost(ctx, hostname) +} diff --git a/commons/webhook/ssrf.go b/commons/webhook/ssrf.go index 52649cfb..a5e2eaad 100644 --- a/commons/webhook/ssrf.go +++ b/commons/webhook/ssrf.go @@ -2,170 +2,48 @@ package webhook import ( "context" + "errors" "fmt" - "net" - "net/url" - "strings" -) - -// cidr4 constructs a static *net.IPNet for an IPv4 CIDR. Using a helper keeps -// each entry in the table below to a single readable line. All four octets are -// accepted so future CIDR additions don't require changing the signature. -// -//nolint:unparam // d is currently always 0 for these network addresses; kept for generality. -func cidr4(a, b, c, d byte, ones int) *net.IPNet { - return &net.IPNet{ - IP: net.IPv4(a, b, c, d).To4(), - Mask: net.CIDRMask(ones, 32), - } -} -// cgnatBlock is the CGNAT (Carrier-Grade NAT) range defined by RFC 6598. -// Cloud providers frequently use this range for internal routing, so it must -// be blocked to prevent SSRF via addresses like 100.64.0.1. -// -//nolint:gochecknoglobals // package-level CIDR block is intentional for SSRF protection -var cgnatBlock = cidr4(100, 64, 0, 0, 10) - -// additionalBlockedRanges holds CIDR blocks that are not covered by the -// standard net.IP predicates (IsPrivate, IsLoopback, etc.) but must be -// blocked to prevent SSRF attacks. -// -// All entries are compile-time-constructed net.IPNet literals — no runtime -// string parsing, no init() required, and typos surface as test failures -// rather than startup panics. 
-// -//nolint:gochecknoglobals // package-level slice is intentional for SSRF protection -var additionalBlockedRanges = []*net.IPNet{ - cidr4(0, 0, 0, 0, 8), // 0.0.0.0/8 — "this network" (RFC 1122 §3.2.1.3) - cidr4(192, 0, 0, 0, 24), // 192.0.0.0/24 — IETF protocol assignments (RFC 6890) - cidr4(192, 0, 2, 0, 24), // 192.0.2.0/24 — TEST-NET-1 documentation (RFC 5737) - cidr4(198, 18, 0, 0, 15), // 198.18.0.0/15 — benchmarking (RFC 2544) - cidr4(198, 51, 100, 0, 24), // 198.51.100.0/24 — TEST-NET-2 documentation (RFC 5737) - cidr4(203, 0, 113, 0, 24), // 203.0.113.0/24 — TEST-NET-3 documentation (RFC 5737) - cidr4(240, 0, 0, 0, 4), // 240.0.0.0/4 — reserved/future use (RFC 1112) -} + libSSRF "github.com/LerianStudio/lib-commons/v4/commons/security/ssrf" +) -// resolveAndValidateIP performs a single DNS lookup for the hostname in rawURL, -// validates every resolved IP against the SSRF blocklist, and returns a new URL -// with the hostname replaced by the first resolved IP (DNS pinning). -// -// Combining validation and pinning into one lookup eliminates the TOCTOU window -// that exists when validateResolvedIP and pinResolvedIP are called sequentially: -// a DNS rebinding attack could change the record between those two calls, causing -// the pinned IP to differ from the validated one. +// resolveAndValidateIP delegates to the canonical [libSSRF.ResolveAndValidate] +// function and maps its result back to the 4-return-value signature expected by +// [deliverToEndpoint]. This keeps the webhook package's internal call-sites +// stable while removing the duplicated SSRF implementation. // // On success it returns: -// - pinnedURL — original URL with the hostname replaced by the first resolved IP. -// - originalAuthority — the original host:port authority (u.Host), suitable for the -// HTTP Host header so explicit non-default ports are preserved. -// - sniHostname — the bare hostname (u.Hostname(), port stripped), suitable for -// TLS SNI / certificate verification. 
-// -// DNS lookup failures are fail-closed: if the hostname cannot be resolved, the -// URL is rejected. When no resolved IP can be parsed from the DNS response the -// URL is considered unresolvable and an error is returned. +// - pinnedURL — original URL with the hostname replaced by the first safe resolved IP. +// - originalAuthority — the original host:port authority, suitable for the HTTP Host header. +// - sniHostname — the bare hostname (port stripped), suitable for TLS SNI / certificate verification. +// +// Error mapping: +// - [libSSRF.ErrBlocked] → [ErrSSRFBlocked] +// - [libSSRF.ErrDNSFailed] → [ErrSSRFBlocked] +// - [libSSRF.ErrInvalidURL] → [ErrInvalidURL] func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL, originalAuthority, sniHostname string, err error) { - u, err := url.Parse(rawURL) - if err != nil { - return "", "", "", fmt.Errorf("%w: %w", ErrInvalidURL, err) - } - - if err := validateScheme(u.Scheme); err != nil { - return "", "", "", err - } - - host := u.Hostname() - if host == "" { - return "", "", "", fmt.Errorf("%w: empty hostname", ErrInvalidURL) - } - - ips, dnsErr := net.DefaultResolver.LookupHost(ctx, host) - if dnsErr != nil { - return "", "", "", fmt.Errorf("%w: DNS lookup failed for %s: %w", ErrSSRFBlocked, host, dnsErr) - } - - if len(ips) == 0 { - return "", "", "", fmt.Errorf("%w: DNS returned no addresses for %s", ErrSSRFBlocked, host) + result, resolveErr := libSSRF.ResolveAndValidate(ctx, rawURL) + if resolveErr != nil { + return "", "", "", mapSSRFError(resolveErr) } - var firstValidIP string - - for _, ipStr := range ips { - ip := net.ParseIP(ipStr) - if ip == nil { - continue - } - - if isPrivateIP(ip) { - return "", "", "", fmt.Errorf("%w: resolved IP %s is private/loopback", ErrSSRFBlocked, ipStr) - } - - if firstValidIP == "" { - firstValidIP = ipStr - } - } - - if firstValidIP == "" { - return "", "", "", fmt.Errorf("%w: no valid IPs resolved for %s", ErrInvalidURL, host) - } - - // Preserve the 
original authority (host:port) for the HTTP Host header - // before rewriting u.Host to the pinned IP. - authority := u.Host - - // Pin to first valid resolved IP to prevent DNS rebinding across retries. - port := u.Port() + return result.PinnedURL, result.Authority, result.SNIHostname, nil +} +// mapSSRFError translates sentinel errors from the canonical ssrf package into +// the webhook package's error types so that existing callers (and tests) that +// check [errors.Is] against [ErrSSRFBlocked] / [ErrInvalidURL] continue to +// work without modification. +func mapSSRFError(err error) error { switch { - case port != "": - u.Host = net.JoinHostPort(firstValidIP, port) - case strings.Contains(firstValidIP, ":"): - // Bare IPv6 literal must be bracket-wrapped for url.URL.Host. - u.Host = "[" + firstValidIP + "]" + case errors.Is(err, libSSRF.ErrInvalidURL): + return fmt.Errorf("%w: %w", ErrInvalidURL, err) + case errors.Is(err, libSSRF.ErrBlocked): + return fmt.Errorf("%w: %w", ErrSSRFBlocked, err) + case errors.Is(err, libSSRF.ErrDNSFailed): + return fmt.Errorf("%w: %w", ErrSSRFBlocked, err) default: - u.Host = firstValidIP + return fmt.Errorf("%w: %w", ErrSSRFBlocked, err) } - - return u.String(), authority, host, nil -} - -// validateScheme checks that the URL scheme is http or https. All other -// schemes are rejected to prevent SSRF via exotic protocols (file://, gopher://, etc.). -// This is a pure function with no DNS or I/O dependency, making it safe for -// isolated unit testing. -func validateScheme(scheme string) error { - s := strings.ToLower(scheme) - if s != "http" && s != "https" { - return fmt.Errorf("%w: scheme %q not allowed", ErrSSRFBlocked, scheme) - } - - return nil -} - -// isPrivateIP reports whether ip is in a private, loopback, link-local, -// unspecified, CGNAT, multicast, or other reserved range that must not be -// contacted by webhook delivery (SSRF protection). 
-// -// In addition to the ranges covered by the standard net.IP predicates, this -// function checks the additionalBlockedRanges slice which covers RFC-defined -// special-purpose blocks not included in Go's net package. -func isPrivateIP(ip net.IP) bool { - if ip.IsLoopback() || - ip.IsPrivate() || - ip.IsLinkLocalUnicast() || - ip.IsLinkLocalMulticast() || - ip.IsMulticast() || - ip.IsUnspecified() || - cgnatBlock.Contains(ip) { - return true - } - - for _, block := range additionalBlockedRanges { - if block.Contains(ip) { - return true - } - } - - return false } diff --git a/commons/webhook/ssrf_test.go b/commons/webhook/ssrf_test.go index b0ea8b4c..a8888745 100644 --- a/commons/webhook/ssrf_test.go +++ b/commons/webhook/ssrf_test.go @@ -11,73 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -// --------------------------------------------------------------------------- -// isPrivateIP — pure function, no DNS, safe to unit-test exhaustively. -// --------------------------------------------------------------------------- - -func TestIsPrivateIP(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - ip string - private bool - }{ - // Loopback - {name: "IPv4 loopback", ip: "127.0.0.1", private: true}, - {name: "IPv6 loopback", ip: "::1", private: true}, - - // RFC 1918 private ranges - {name: "10.0.0.0/8", ip: "10.0.0.1", private: true}, - {name: "10.255.255.255", ip: "10.255.255.255", private: true}, - {name: "172.16.0.0/12", ip: "172.16.0.1", private: true}, - {name: "172.31.255.255", ip: "172.31.255.255", private: true}, - {name: "192.168.0.0/16", ip: "192.168.0.1", private: true}, - {name: "192.168.255.255", ip: "192.168.255.255", private: true}, - - // Link-local unicast - {name: "IPv4 link-local", ip: "169.254.1.1", private: true}, - {name: "IPv6 link-local unicast", ip: "fe80::1", private: true}, - - // Link-local multicast - {name: "IPv4 link-local multicast", ip: "224.0.0.1", private: true}, - - // Unspecified addresses (0.0.0.0 / 
::) - {name: "IPv4 unspecified", ip: "0.0.0.0", private: true}, - {name: "IPv6 unspecified", ip: "::", private: true}, - - // CGNAT range (RFC 6598): 100.64.0.0/10 - {name: "CGNAT low end", ip: "100.64.0.1", private: true}, - {name: "CGNAT mid range", ip: "100.100.100.100", private: true}, - {name: "CGNAT high end", ip: "100.127.255.254", private: true}, - // Just below CGNAT range — should NOT be private - {name: "100.63.255.255 not CGNAT", ip: "100.63.255.255", private: false}, - // Just above CGNAT range — should NOT be private - {name: "100.128.0.1 not CGNAT", ip: "100.128.0.1", private: false}, - - // Public IPs — should NOT be private - {name: "Google DNS", ip: "8.8.8.8", private: false}, - {name: "Cloudflare", ip: "1.1.1.1", private: false}, - {name: "Public IPv6", ip: "2001:4860:4860::8888", private: false}, - - // Edge: 172.15.x is NOT private (just below 172.16) - {name: "172.15.255.255 not private", ip: "172.15.255.255", private: false}, - // Edge: 172.32.x is NOT private (just above 172.31) - {name: "172.32.0.1 not private", ip: "172.32.0.1", private: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ip := net.ParseIP(tt.ip) - assert.NotNil(t, ip, "failed to parse IP: %s", tt.ip) - assert.Equal(t, tt.private, isPrivateIP(ip), - "isPrivateIP(%s) = %v, want %v", tt.ip, isPrivateIP(ip), tt.private) - }) - } -} - // --------------------------------------------------------------------------- // resolveAndValidateIP — URL parsing edge cases only. We do NOT hit real DNS. // --------------------------------------------------------------------------- @@ -128,61 +61,6 @@ func TestResolveAndValidateIP_UnsupportedSchemes(t *testing.T) { } } -func TestValidateScheme_AllowedSchemes(t *testing.T) { - t.Parallel() - - // These schemes should pass the scheme check — tested via the pure - // validateScheme function to avoid any DNS dependency. 
- tests := []struct { - name string - scheme string - }{ - {name: "http scheme", scheme: "http"}, - {name: "https scheme", scheme: "https"}, - {name: "HTTP uppercase", scheme: "HTTP"}, - {name: "HTTPS uppercase", scheme: "HTTPS"}, - {name: "mixed case Http", scheme: "Http"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - err := validateScheme(tt.scheme) - assert.NoError(t, err, "scheme %q must be allowed", tt.scheme) - }) - } -} - -func TestValidateScheme_BlockedSchemes(t *testing.T) { - t.Parallel() - - // These schemes must be rejected — tested via the pure validateScheme - // function to avoid any DNS dependency. - tests := []struct { - name string - scheme string - }{ - {name: "gopher scheme", scheme: "gopher"}, - {name: "file scheme", scheme: "file"}, - {name: "ftp scheme", scheme: "ftp"}, - {name: "javascript scheme", scheme: "javascript"}, - {name: "data scheme", scheme: "data"}, - {name: "empty scheme", scheme: ""}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - err := validateScheme(tt.scheme) - require.Error(t, err) - assert.ErrorIs(t, err, ErrSSRFBlocked, - "scheme %q must be blocked", tt.scheme) - }) - } -} - // --------------------------------------------------------------------------- // resolveAndValidateIP — URL parsing and scheme blocking (no DNS) // --------------------------------------------------------------------------- @@ -257,46 +135,3 @@ func TestResolveAndValidateIP_PrivateIP(t *testing.T) { "localhost (127.0.0.1) must be blocked as a private/loopback address") } -// --------------------------------------------------------------------------- -// isPrivateIP — additional blocked ranges (RFC-defined special-purpose blocks) -// --------------------------------------------------------------------------- - -// TestResolveAndValidateIP_AllBlockedRanges tests isPrivateIP directly against -// representative IPs from every additional CIDR block listed in 
ssrf.go: -// -// - 0.0.0.1 → 0.0.0.0/8 "this network" (RFC 1122 §3.2.1.3) -// - 192.0.0.1 → 192.0.0.0/24 IETF protocol assignments (RFC 6890) -// - 192.0.2.1 → 192.0.2.0/24 TEST-NET-1 documentation (RFC 5737) -// - 198.18.0.1 → 198.18.0.0/15 benchmarking (RFC 2544) -// - 198.51.100.1 → 198.51.100.0/24 TEST-NET-2 documentation (RFC 5737) -// - 203.0.113.1 → 203.0.113.0/24 TEST-NET-3 documentation (RFC 5737) -// - 240.0.0.1 → 240.0.0.0/4 reserved / future use (RFC 1112) -// - 239.255.255.255 → multicast (net.IP.IsMulticast) -func TestResolveAndValidateIP_AllBlockedRanges(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - ip string - }{ - {name: "0/8 this-network", ip: "0.0.0.1"}, - {name: "192.0.0/24 IETF", ip: "192.0.0.1"}, - {name: "192.0.2/24 TEST-NET-1", ip: "192.0.2.1"}, - {name: "198.18/15 benchmarking", ip: "198.18.0.1"}, - {name: "198.51.100/24 TEST-NET-2", ip: "198.51.100.1"}, - {name: "203.0.113/24 TEST-NET-3", ip: "203.0.113.1"}, - {name: "240/4 reserved", ip: "240.0.0.1"}, - {name: "multicast 239.255.255.255", ip: "239.255.255.255"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ip := net.ParseIP(tt.ip) - require.NotNil(t, ip, "test setup error: %s is not a valid IP", tt.ip) - assert.True(t, isPrivateIP(ip), - "isPrivateIP(%s) must return true for blocked range %s", tt.ip, tt.name) - }) - } -} From 29eaa8424dfe058ab26a1bce5703d16f8c32bbc2 Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 13:09:04 -0300 Subject: [PATCH 6/9] docs: document commons/security/ssrf package in AGENTS.md and CLAUDE.md Add repository shape entry, API invariants section, and other-packages bullet for the new canonical SSRF validation package. 
--- AGENTS.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 886b3b75..be8ca8be 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -47,6 +47,7 @@ Resilience and safety: - `commons/assert`: production-safe assertions with telemetry integration and domain predicates - `commons/safe`: panic-free math/regex/slice operations with error returns - `commons/security`: sensitive field detection and handling +- `commons/security/ssrf`: canonical SSRF validation — IP blocking (CIDR blocklist + stdlib predicates), hostname blocking (metadata endpoints, dangerous suffixes), URL validation, DNS-pinned resolution with TOCTOU elimination - `commons/errgroup`: goroutine coordination with panic recovery - `commons/certificate`: thread-safe TLS certificate manager with hot reload, PEM file loading, PKCS#8/PKCS#1/EC key support, and `tls.Config` integration @@ -137,6 +138,17 @@ Build and shell: - TLS integration: `TLSCertificate() tls.Certificate` (returns populated `tls.Certificate` struct); `GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error)` — suitable for assignment to `tls.Config.GetCertificate` for transparent hot-reload. - Package-level helper: `LoadFromFiles(certPath, keyPath string) (*x509.Certificate, crypto.Signer, error)` — validates without touching any manager state, useful for pre-flight checking before calling `Rotate`. +### SSRF validation (`commons/security/ssrf`) + +- `IsBlockedIP(net.IP)` and `IsBlockedAddr(netip.Addr)` for IP-level SSRF blocking. `IsBlockedAddr` is the core; `IsBlockedIP` delegates after conversion. +- `IsBlockedHostname(hostname)` for hostname-level blocking: localhost, cloud metadata endpoints, `.local`/`.internal`/`.cluster.local` suffixes. +- `BlockedPrefixes() []netip.Prefix` returns a copy of the canonical CIDR blocklist (8 ranges: this-network, CGNAT, IETF assignments, TEST-NET-1/2/3, benchmarking, reserved). 
+- `ValidateURL(ctx, rawURL, opts...)` validates scheme, hostname, and parsed IP without DNS. +- `ResolveAndValidate(ctx, rawURL, opts...) (*ResolveResult, error)` performs DNS-pinned validation. `ResolveResult` has `PinnedURL`, `Authority` (host:port for HTTP Host), `SNIHostname` (for TLS ServerName). +- Functional options: `WithHTTPSOnly()`, `WithAllowPrivateNetwork()`, `WithLookupFunc(fn)`, `WithAllowHostname(hostname)`. +- Sentinel errors: `ErrBlocked`, `ErrInvalidURL`, `ErrDNSFailed`. +- Both `commons/webhook` and `commons/net/http` delegate to this package — it is the single source of truth for SSRF blocking across all Lerian services. + ### Assertions (`commons/assert`) - `New(ctx, logger, component, operation) *Asserter` and return errors instead of panicking. @@ -238,6 +250,7 @@ Build and shell: - **Pointers:** `String()`, `Bool()`, `Time()`, `Int()`, `Int64()`, `Float64()`. - **Cron:** `Parse(expr) (Schedule, error)`; `Schedule.Next(t) (time.Time, error)`. - **Security:** `IsSensitiveField(name)`, `DefaultSensitiveFields()`, `DefaultSensitiveFieldsMap()`. +- **SSRF:** `IsBlockedIP()`, `IsBlockedAddr()`, `IsBlockedHostname()`, `BlockedPrefixes()`, `ValidateURL()`, `ResolveAndValidate()`. Single source of truth for all SSRF protection. Both `webhook` and `net/http` delegate here. - **Transaction:** `BuildIntentPlan()` + `ValidateBalanceEligibility()` + `ApplyPosting()` with typed `IntentPlan`, `Posting`, `LedgerTarget`. `ResolveOperation(pending, isSource, status) (Operation, error)`. - **Constants:** `SanitizeMetricLabel(value) string` for OTEL label safety. - **Certificate:** `NewManager(certPath, keyPath) (*Manager, error)`; `Rotate(cert, key)`, `TLSCertificate()`, `GetCertificateFunc()`; package-level `LoadFromFiles(certPath, keyPath)` for pre-flight validation. 
From 7afa74e7a2c45fa6b336f232d7a12f9d2a6a777e Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 14:48:59 -0300 Subject: [PATCH 7/9] style(security/ssrf): fix comment alignment in CIDR blocklist X-Lerian-Ref: 0x1 --- commons/security/ssrf/ssrf.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/commons/security/ssrf/ssrf.go b/commons/security/ssrf/ssrf.go index 8751feab..6ecf71d8 100644 --- a/commons/security/ssrf/ssrf.go +++ b/commons/security/ssrf/ssrf.go @@ -31,13 +31,13 @@ var ( //nolint:gochecknoglobals // package-level CIDR blocklist is intentional for SSRF protection var blockedPrefixes = []netip.Prefix{ netip.MustParsePrefix("0.0.0.0/8"), // "this network" (RFC 1122 S3.2.1.3) - netip.MustParsePrefix("100.64.0.0/10"), // CGNAT (RFC 6598) - netip.MustParsePrefix("192.0.0.0/24"), // IETF protocol assignments (RFC 6890) - netip.MustParsePrefix("192.0.2.0/24"), // TEST-NET-1 documentation (RFC 5737) - netip.MustParsePrefix("198.18.0.0/15"), // benchmarking (RFC 2544) - netip.MustParsePrefix("198.51.100.0/24"), // TEST-NET-2 documentation (RFC 5737) - netip.MustParsePrefix("203.0.113.0/24"), // TEST-NET-3 documentation (RFC 5737) - netip.MustParsePrefix("240.0.0.0/4"), // reserved / future use (RFC 1112) + netip.MustParsePrefix("100.64.0.0/10"), // CGNAT (RFC 6598) + netip.MustParsePrefix("192.0.0.0/24"), // IETF protocol assignments (RFC 6890) + netip.MustParsePrefix("192.0.2.0/24"), // TEST-NET-1 documentation (RFC 5737) + netip.MustParsePrefix("198.18.0.0/15"), // benchmarking (RFC 2544) + netip.MustParsePrefix("198.51.100.0/24"), // TEST-NET-2 documentation (RFC 5737) + netip.MustParsePrefix("203.0.113.0/24"), // TEST-NET-3 documentation (RFC 5737) + netip.MustParsePrefix("240.0.0.0/4"), // reserved / future use (RFC 1112) } // BlockedPrefixes returns a copy of the canonical CIDR blocklist. 
The returned From f35e341093e116434fc54e6e6cc1e6d60960b8b8 Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 14:49:04 -0300 Subject: [PATCH 8/9] test(webhook): refactor SSRF tests and add mapSSRFError coverage Remove duplicate test functions (InvalidScheme was duplicate of BlockedSchemes, PrivateIP replaced by simpler BlockedHostname). Add dedicated TestMapSSRFError covering all four sentinel translation branches. Upgrade assert.Error to require.Error for fail-fast on nil errors. X-Lerian-Ref: 0x1 --- commons/webhook/ssrf_test.go | 131 ++++++++++++++++------------------- 1 file changed, 60 insertions(+), 71 deletions(-) diff --git a/commons/webhook/ssrf_test.go b/commons/webhook/ssrf_test.go index a8888745..de45f395 100644 --- a/commons/webhook/ssrf_test.go +++ b/commons/webhook/ssrf_test.go @@ -4,39 +4,43 @@ package webhook import ( "context" - "net" + "errors" + "fmt" "testing" + libSSRF "github.com/LerianStudio/lib-commons/v4/commons/security/ssrf" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- -// resolveAndValidateIP — URL parsing edge cases only. We do NOT hit real DNS. +// resolveAndValidateIP — URL validation (no real DNS unless noted) // --------------------------------------------------------------------------- -func TestResolveAndValidateIP_InvalidURLMalformed(t *testing.T) { +// TestResolveAndValidateIP_InvalidURL checks that a completely malformed URL +// is rejected with ErrInvalidURL. 
+func TestResolveAndValidateIP_InvalidURL(t *testing.T) { t.Parallel() - _, _, _, err := resolveAndValidateIP(context.Background(), "://missing-scheme") - assert.Error(t, err) + _, _, _, err := resolveAndValidateIP(context.Background(), "://no-scheme") + require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) } -func TestResolveAndValidateIP_EmptyHostnameHTTP(t *testing.T) { +// TestResolveAndValidateIP_EmptyHostname checks that an empty hostname is +// rejected with ErrInvalidURL and a descriptive message. +func TestResolveAndValidateIP_EmptyHostname(t *testing.T) { t.Parallel() _, _, _, err := resolveAndValidateIP(context.Background(), "http://") - assert.Error(t, err) + require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidURL) assert.Contains(t, err.Error(), "empty hostname") } -// --------------------------------------------------------------------------- -// resolveAndValidateIP — scheme validation (blocks non-HTTP/HTTPS schemes). -// --------------------------------------------------------------------------- - -func TestResolveAndValidateIP_UnsupportedSchemes(t *testing.T) { +// TestResolveAndValidateIP_BlockedSchemes verifies that non-HTTP/HTTPS schemes +// are rejected before DNS lookup. +func TestResolveAndValidateIP_BlockedSchemes(t *testing.T) { t.Parallel() tests := []struct { @@ -55,83 +59,68 @@ func TestResolveAndValidateIP_UnsupportedSchemes(t *testing.T) { t.Parallel() _, _, _, err := resolveAndValidateIP(context.Background(), tt.url) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrSSRFBlocked) + require.Error(t, err) + assert.ErrorIs(t, err, ErrSSRFBlocked, + "non-HTTP/HTTPS scheme must return ErrSSRFBlocked") }) } } +// TestResolveAndValidateIP_BlockedHostname confirms that hostnames rejected by +// hostname-level SSRF validation are mapped to ErrSSRFBlocked. 
+func TestResolveAndValidateIP_BlockedHostname(t *testing.T) { + t.Parallel() + + _, _, _, err := resolveAndValidateIP(context.Background(), "http://localhost") + require.Error(t, err) + assert.ErrorIs(t, err, ErrSSRFBlocked, + "localhost must be blocked by hostname-level SSRF protection") +} + // --------------------------------------------------------------------------- -// resolveAndValidateIP — URL parsing and scheme blocking (no DNS) +// mapSSRFError — sentinel error translation (fail-closed for unknown errors) // --------------------------------------------------------------------------- -// TestResolveAndValidateIP_InvalidScheme checks that non-HTTP/HTTPS schemes are -// rejected before DNS lookup. -func TestResolveAndValidateIP_InvalidScheme(t *testing.T) { +// TestMapSSRFError verifies that all four branches of mapSSRFError translate +// canonical SSRF sentinel errors into the webhook package's error types, and +// that unrecognized errors are mapped to ErrSSRFBlocked (fail-closed). 
+func TestMapSSRFError(t *testing.T) { t.Parallel() tests := []struct { - name string - url string + name string + input error + wantIs error }{ - {name: "gopher scheme", url: "gopher://example.com"}, - {name: "file scheme", url: "file:///etc/passwd"}, - {name: "ftp scheme", url: "ftp://example.com/file"}, + { + name: "ErrInvalidURL maps to webhook ErrInvalidURL", + input: fmt.Errorf("validation: %w", libSSRF.ErrInvalidURL), + wantIs: ErrInvalidURL, + }, + { + name: "ErrBlocked maps to webhook ErrSSRFBlocked", + input: fmt.Errorf("blocked: %w", libSSRF.ErrBlocked), + wantIs: ErrSSRFBlocked, + }, + { + name: "ErrDNSFailed maps to webhook ErrSSRFBlocked", + input: fmt.Errorf("dns: %w", libSSRF.ErrDNSFailed), + wantIs: ErrSSRFBlocked, + }, + { + name: "unknown error maps to webhook ErrSSRFBlocked (fail-closed)", + input: errors.New("unexpected internal error"), + wantIs: ErrSSRFBlocked, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, _, _, err := resolveAndValidateIP(context.Background(), tt.url) + err := mapSSRFError(tt.input) require.Error(t, err) - assert.ErrorIs(t, err, ErrSSRFBlocked, - "non-HTTP/HTTPS scheme must return ErrSSRFBlocked") + assert.ErrorIs(t, err, tt.wantIs) }) } } - -// TestResolveAndValidateIP_EmptyHostname checks that an empty hostname is rejected. -func TestResolveAndValidateIP_EmptyHostname(t *testing.T) { - t.Parallel() - - _, _, _, err := resolveAndValidateIP(context.Background(), "http://") - require.Error(t, err) - assert.ErrorIs(t, err, ErrInvalidURL) - assert.Contains(t, err.Error(), "empty hostname") -} - -// TestResolveAndValidateIP_InvalidURL checks that a completely malformed URL -// is rejected with ErrInvalidURL. 
-func TestResolveAndValidateIP_InvalidURL(t *testing.T) { - t.Parallel() - - _, _, _, err := resolveAndValidateIP(context.Background(), "://no-scheme") - require.Error(t, err) - assert.ErrorIs(t, err, ErrInvalidURL) -} - -// TestResolveAndValidateIP_PrivateIP confirms that hostnames that resolve to a -// loopback/private address are blocked. "localhost" resolves to 127.0.0.1 -// (loopback), which must trigger ErrSSRFBlocked. -// -// This test does perform a real DNS lookup for "localhost". On all POSIX -// systems "localhost" is defined in /etc/hosts as 127.0.0.1, so the lookup -// is local and requires no network access. Environments that strip /etc/hosts -// may see the DNS fall-back path (original URL returned, no error), in which -// case the test is skipped gracefully. -func TestResolveAndValidateIP_PrivateIP(t *testing.T) { - t.Parallel() - - // Probe the local DNS before asserting — skip if localhost doesn't resolve. - addrs, lookupErr := net.LookupHost("localhost") - if lookupErr != nil || len(addrs) == 0 { - t.Skip("localhost DNS lookup failed or returned no results — skipping SSRF private-IP test") - } - - _, _, _, err := resolveAndValidateIP(context.Background(), "http://localhost") - require.Error(t, err) - assert.ErrorIs(t, err, ErrSSRFBlocked, - "localhost (127.0.0.1) must be blocked as a private/loopback address") -} - From ea996211e5ccfbb243c97a6245d26e9c80b9a48e Mon Sep 17 00:00:00 2001 From: Fred Amaral Date: Mon, 30 Mar 2026 15:05:46 -0300 Subject: [PATCH 9/9] fix: apply CodeRabbit auto-fixes for PR #421 --- commons/net/http/idempotency/idempotency.go | 25 ++++++++++++++++++--- commons/security/ssrf/ssrf.go | 6 +++++ commons/security/ssrf/ssrf_test.go | 8 +++++++ commons/security/ssrf/validate.go | 4 ++-- 4 files changed, 38 insertions(+), 5 deletions(-) diff --git a/commons/net/http/idempotency/idempotency.go b/commons/net/http/idempotency/idempotency.go index 96a6f3f6..06265154 100644 --- a/commons/net/http/idempotency/idempotency.go +++ 
b/commons/net/http/idempotency/idempotency.go @@ -2,6 +2,8 @@ package idempotency import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -152,6 +154,15 @@ func (m *Middleware) Check() fiber.Handler { return m.handle } +// redactKey returns a truncated SHA-256 hash of a Redis key for safe logging. +// Idempotency keys are client-controlled and tenant-scoped, so logging them +// verbatim would emit high-cardinality identifiers and potentially leak tenant +// or client information during incidents. +func redactKey(key string) string { + h := sha256.Sum256([]byte(key)) + return hex.EncodeToString(h[:8]) +} + func (m *Middleware) handle(c *fiber.Ctx) error { // Idempotency only applies to mutating methods. if c.Method() == fiber.MethodGet || c.Method() == fiber.MethodOptions || c.Method() == fiber.MethodHead { @@ -230,12 +241,20 @@ func (m *Middleware) handleDuplicate( // Unexpected Redis error (timeout, connection failure) — fail open. m.logger.Log(ctx, log.LevelWarn, "idempotency: failed to read key state, failing open", - log.String("key", key), log.Err(keyErr), + log.String("key_hash", redactKey(key)), log.Err(keyErr), ) return c.Next() } + // The marker has vanished between the SetNX (which saw it) and this Get. + // This happens when the original request failed and deleted the key, or + // the TTL expired in the narrow window. Fail open so the duplicate can + // be retried rather than returning a false "already processed" response. + if errors.Is(keyErr, redis.Nil) { + return c.Next() + } + // Try to replay the cached response (true idempotency). cached, cacheErr := client.Get(ctx, responseKey).Result() @@ -244,7 +263,7 @@ func (m *Middleware) handleDuplicate( // Unexpected Redis error reading cached response — fail open. 
m.logger.Log(ctx, log.LevelWarn, "idempotency: failed to read cached response, failing open", - log.String("key", responseKey), log.Err(cacheErr), + log.String("key_hash", redactKey(responseKey)), log.Err(cacheErr), ) return c.Next() @@ -256,7 +275,7 @@ func (m *Middleware) handleDuplicate( // to the generic "already processed" response (fail-open). m.logger.Log(ctx, log.LevelWarn, "idempotency: failed to unmarshal cached response, falling through to generic reply", - log.String("key", responseKey), log.Err(unmarshalErr), + log.String("key_hash", redactKey(responseKey)), log.Err(unmarshalErr), ) } else { // Replay persisted headers first so the caller sees diff --git a/commons/security/ssrf/ssrf.go b/commons/security/ssrf/ssrf.go index 6ecf71d8..47b35293 100644 --- a/commons/security/ssrf/ssrf.go +++ b/commons/security/ssrf/ssrf.go @@ -59,6 +59,12 @@ func BlockedPrefixes() []netip.Prefix { // IsLinkLocalMulticast, IsMulticast, IsUnspecified. // 3. Custom blocklist ([BlockedPrefixes]). func IsBlockedAddr(addr netip.Addr) bool { + // Step 0: the zero-value netip.Addr{} is not a real IP address. Treating + // it as safe would violate the fail-closed principle — block it. + if !addr.IsValid() { + return true + } + // Step 1: unmap IPv4-mapped IPv6 so that ::ffff:10.0.0.1 is treated // identically to 10.0.0.1. addr = addr.Unmap() diff --git a/commons/security/ssrf/ssrf_test.go b/commons/security/ssrf/ssrf_test.go index 127bc628..9b5116f3 100644 --- a/commons/security/ssrf/ssrf_test.go +++ b/commons/security/ssrf/ssrf_test.go @@ -173,6 +173,14 @@ func TestIsBlockedAddr(t *testing.T) { } } +func TestIsBlockedAddr_ZeroValue(t *testing.T) { + t.Parallel() + + // The zero-value netip.Addr{} is invalid and must be blocked (fail-closed). 
+ var zero netip.Addr + assert.True(t, IsBlockedAddr(zero), "IsBlockedAddr(netip.Addr{}) must return true (fail-closed)") +} + // --------------------------------------------------------------------------- // IsBlockedIP — legacy net.IP interface // --------------------------------------------------------------------------- diff --git a/commons/security/ssrf/validate.go b/commons/security/ssrf/validate.go index d1524665..5ac1e888 100644 --- a/commons/security/ssrf/validate.go +++ b/commons/security/ssrf/validate.go @@ -39,7 +39,7 @@ type ResolveResult struct { // Errors wrap [ErrBlocked] or [ErrInvalidURL] for programmatic inspection via // [errors.Is]. func ValidateURL(ctx context.Context, rawURL string, opts ...Option) error { - _ = ctx // reserved for future use (e.g. tracing) + _ = ctx // TODO(ssrf): use ctx for tracing/metrics integration when telemetry is wired in cfg := buildConfig(opts) @@ -83,7 +83,7 @@ func ValidateURL(ctx context.Context, rawURL string, opts ...Option) error { // 1. Parse URL, validate scheme (http/https only, or https-only with [WithHTTPSOnly]). // 2. Check hostname blocking ([IsBlockedHostname]). // 3. Resolve DNS (single lookup via [net.DefaultResolver] or [WithLookupFunc]). -// 4. Validate ALL resolved IPs ([IsBlockedAddr]). +// 4. Validate resolved IPs — reject if ANY IP is blocked ([IsBlockedAddr]). // 5. Pin URL to first safe IP, return [ResolveResult]. // // DNS lookup failures are fail-closed: if the hostname cannot be resolved the