diff --git a/.env.reference b/.env.reference index fd50817b..4e877766 100644 --- a/.env.reference +++ b/.env.reference @@ -20,6 +20,7 @@ # # Generated from source scan — 2026-03-23 # Updated: 2026-03-24 — Added SECURITY_TIER override and fail-closed rate-limit notes +# Updated: 2026-03-28 — Added systemplane catalog env vars (postgres, redis, rabbitmq, auth, telemetry, cors, server) # ============================================================================= # ----------------------------------------------------------------------------- @@ -184,6 +185,26 @@ ENV= # Default: "" (stack traces included) GO_ENV= +# ----------------------------------------------------------------------------- +# HTTP SERVER TLS — commons/systemplane/catalog (keys_shared.go: appServerKeys) +# ----------------------------------------------------------------------------- +# These env vars are registered in the systemplane shared catalog under the +# "server.tls" group (ApplyBootstrapOnly, non-mutable at runtime). +# They are consumed during service startup to configure HTTPS. +# Leave empty to disable TLS (plain HTTP mode). + +# Path to the PEM-encoded TLS certificate (or certificate chain) file. +# Required when enabling HTTPS. Must be a regular file readable by the process. +# Type: string (file path) +# Default: (none — TLS disabled when empty) +SERVER_TLS_CERT_FILE= + +# Path to the PEM-encoded TLS private key file. +# Required when enabling HTTPS. Must be mode 0600 or stricter. 
+# Type: string (file path) +# Default: (none — TLS disabled when empty) +SERVER_TLS_KEY_FILE= + # ----------------------------------------------------------------------------- # HTTP HANDLERS — commons/net/http (commons/net/http/handler.go) # ----------------------------------------------------------------------------- @@ -323,6 +344,331 @@ RATE_LIMIT_REDIS_TIMEOUT_MS=500 # Operational note: if Redis is not configured and strict enforcement is active, # the constructor fail-closes and all requests return 503 until Redis is wired. +# ----------------------------------------------------------------------------- +# SYSTEMPLANE SHARED CATALOG — commons/systemplane/catalog (keys_shared.go, keys_postgres.go, keys_redis.go) +# ----------------------------------------------------------------------------- +# The systemplane catalog defines canonical environment variable mappings for all +# service components. Services that adopt the systemplane dynamic config plane +# register these keys via KeyDef lists; they are applied either at bootstrap time +# (ApplyBootstrapOnly) or live-read from the running config snapshot +# (ApplyLiveRead / ApplyBundleRebuild). +# +# NOTE: Some vars below overlap with direct env-var reads elsewhere in the library +# (e.g., OTEL_EXPORTER_OTLP_ENDPOINT, RATE_LIMIT_ENABLED). The catalog +# definitions represent the canonical, hot-reloadable control plane equivalent. +# Deployers only need to set each var once — the systemplane bootstrap picks them +# up from the same environment. + +## -- Application / HTTP Server -- + +# HTTP server listen address (host:port format). +# ApplyBehavior: BootstrapOnly (restart required to change) +# Type: string +# Default: (none — required for server startup) +SERVER_ADDRESS= + +# Maximum HTTP request body size in bytes. 
+# ApplyBehavior: BootstrapOnly +# Type: int +# Default: (none — framework default applies) +HTTP_BODY_LIMIT_BYTES= + +## -- CORS (systemplane catalog, live-reloadable) -- +# These catalog keys provide hot-reloadable CORS policy distinct from the static +# ACCESS_CONTROL_ALLOW_* env vars read by the HTTP middleware at construction time. + +# CORS allowed origins (comma-separated), sourced from systemplane catalog. +# ApplyBehavior: LiveRead (no restart needed) +# Type: string +# Default: (none — falls back to ACCESS_CONTROL_ALLOW_ORIGIN) +CORS_ALLOWED_ORIGINS= + +# CORS allowed methods (comma-separated), sourced from systemplane catalog. +# ApplyBehavior: LiveRead +# Type: string +# Default: (none — falls back to ACCESS_CONTROL_ALLOW_METHODS) +CORS_ALLOWED_METHODS= + +# CORS allowed headers (comma-separated), sourced from systemplane catalog. +# ApplyBehavior: LiveRead +# Type: string +# Default: (none — falls back to ACCESS_CONTROL_ALLOW_HEADERS) +CORS_ALLOWED_HEADERS= + +## -- Rate Limiting (systemplane catalog key) -- + +# Rate limit window duration in seconds when read via systemplane dynamic config. +# Distinct from RATE_LIMIT_WINDOW_SEC (read directly by the ratelimit middleware). +# RATE_LIMIT_EXPIRY_SEC feeds the catalog-managed config snapshot; products that +# adopt systemplane use this key for live-reloadable window adjustment. +# ApplyBehavior: LiveRead +# Type: int (seconds) +# Default: (none — middleware falls back to RATE_LIMIT_WINDOW_SEC default of 60) +RATE_LIMIT_EXPIRY_SEC= + +## -- Authentication -- +# Auth keys use MatchEnvVars (multiple valid variable names per key) because +# products and plugins follow different naming conventions. + +# Enable authentication middleware. +# ApplyBehavior: BootstrapOnly +# Accepted variable names: AUTH_ENABLED, PLUGIN_AUTH_ENABLED +# Type: bool +# Default: (none — auth disabled when absent) +AUTH_ENABLED= +# PLUGIN_AUTH_ENABLED= (alias — same effect) + +# Auth service network address (host:port or URL). 
+# ApplyBehavior: BootstrapOnly +# Accepted variable names: AUTH_ADDRESS, PLUGIN_AUTH_ADDRESS +# Type: string +# Default: (none — required when auth is enabled) +AUTH_ADDRESS= +# PLUGIN_AUTH_ADDRESS= (alias — same effect) + +# OAuth2 / OIDC client ID for the auth service. +# ApplyBehavior: BootstrapOnly +# Accepted variable names: AUTH_CLIENT_ID, PLUGIN_AUTH_CLIENT_ID +# Type: string +# Default: (none — required when auth is enabled) +AUTH_CLIENT_ID= +# PLUGIN_AUTH_CLIENT_ID= (alias — same effect) + +# OAuth2 / OIDC client secret for the auth service. +# Secret: true — stored encrypted at rest in systemplane; never logged. +# ApplyBehavior: BootstrapOnly +# Accepted variable names: AUTH_CLIENT_SECRET, PLUGIN_AUTH_CLIENT_SECRET +# Type: string +# Default: (none — required when auth is enabled) +AUTH_CLIENT_SECRET= +# PLUGIN_AUTH_CLIENT_SECRET= (alias — same effect) + +# Auth token/response cache TTL in seconds. +# ApplyBehavior: BootstrapOnly +# Accepted variable names: AUTH_CACHE_TTL_SEC +# Type: int (seconds) +# Default: (none — no caching when absent) +AUTH_CACHE_TTL_SEC= + +## -- Telemetry (systemplane catalog, bootstrap-only) -- + +# Enable OpenTelemetry instrumentation. +# ApplyBehavior: BootstrapOnly +# Type: bool +# Default: (none — telemetry disabled when absent) +ENABLE_TELEMETRY= + +# OTEL resource service name (overrides OTEL_SERVICE_NAME SDK default). +# ApplyBehavior: BootstrapOnly +# Type: string +# Default: (none — OTEL SDK default applies) +OTEL_RESOURCE_SERVICE_NAME= + +# OTEL instrumentation library name passed to Tracer/Meter constructors. +# ApplyBehavior: BootstrapOnly +# Type: string +# Default: (none) +OTEL_LIBRARY_NAME= + +# OTEL resource service version (reported in spans and metrics). +# ApplyBehavior: BootstrapOnly +# Type: string +# Default: (none) +OTEL_RESOURCE_SERVICE_VERSION= + +# OTEL deployment environment label (e.g., production, staging). 
+# ApplyBehavior: BootstrapOnly +# Type: string +# Default: (none) +OTEL_RESOURCE_DEPLOYMENT_ENVIRONMENT= + +## -- RabbitMQ (systemplane catalog) -- + +# Enable RabbitMQ integration. When false, the messaging bundle is not started. +# ApplyBehavior: BundleRebuildAndReconcile (live toggle) +# Type: bool +# Default: (none — disabled when absent) +RABBITMQ_ENABLED= + +# RabbitMQ AMQP connection URL. +# Secret: true — stored encrypted; never logged. +# ApplyBehavior: BundleRebuild (triggers connection teardown/reconnect) +# Type: string (amqp:// or amqps://) +# Default: (none — required when RabbitMQ is enabled) +RABBITMQ_URL= + +# RabbitMQ exchange name used for publishing events. +# ApplyBehavior: BundleRebuild +# Type: string +# Default: (none — required when RabbitMQ is enabled) +RABBITMQ_EXCHANGE= + +# Prefix prepended to all routing keys (e.g., "org.ledger."). +# ApplyBehavior: LiveRead +# Type: string +# Default: (none) +RABBITMQ_ROUTING_KEY_PREFIX= + +# Per-publish timeout in milliseconds. +# ApplyBehavior: LiveRead +# Type: int (milliseconds) +# Default: (none — framework default applies) +RABBITMQ_PUBLISH_TIMEOUT_MS= + +# Maximum number of publish retries before giving up. +# ApplyBehavior: LiveRead +# Type: int +# Default: (none) +RABBITMQ_MAX_RETRIES= + +# Backoff delay in milliseconds between publish retry attempts. +# ApplyBehavior: LiveRead +# Type: int (milliseconds) +# Default: (none) +RABBITMQ_RETRY_BACKOFF_MS= + +# HMAC secret used to sign outgoing RabbitMQ events. +# Secret: true — stored encrypted; never logged. +# ApplyBehavior: LiveRead +# Type: string +# Default: (none — event signing disabled when absent) +RABBITMQ_EVENT_SIGNING_SECRET= + +## -- PostgreSQL (systemplane catalog, service-level connection config) -- +# These are the service's own PostgreSQL connection parameters as registered in +# the systemplane catalog. Distinct from SYSTEMPLANE_POSTGRES_* which configure +# systemplane's own metadata storage. 
+ +# Primary host +POSTGRES_HOST= + +# Primary port +# Type: int +# Default: 5432 +POSTGRES_PORT=5432 + +# Primary database user +POSTGRES_USER= + +# Primary database password +# Secret: true +POSTGRES_PASSWORD= + +# Primary database name +POSTGRES_DB= + +# Primary SSL mode (disable, require, verify-ca, verify-full) +# NOTE: In strict security tier, only require/verify-ca/verify-full are accepted. +# Type: string +# Default: (none — driver default is "disable") +POSTGRES_SSLMODE= + +# Replica host (optional — omit to disable read replicas) +POSTGRES_REPLICA_HOST= + +# Replica port +# Type: int +POSTGRES_REPLICA_PORT= + +# Replica user +POSTGRES_REPLICA_USER= + +# Replica password (Secret: true) +POSTGRES_REPLICA_PASSWORD= + +# Replica database name +POSTGRES_REPLICA_DB= + +# Replica SSL mode +POSTGRES_REPLICA_SSLMODE= + +# Maximum open connections in the pool (LiveRead — applied without restart) +# Type: int +# Default: (none — sql.DB default of 0, meaning unlimited) +POSTGRES_MAX_OPEN_CONNS= + +# Maximum idle connections in the pool (LiveRead) +# Type: int +# Default: (none — sql.DB default of 2) +POSTGRES_MAX_IDLE_CONNS= + +# Maximum connection lifetime in minutes (LiveRead) +# Type: int (minutes) +# Default: (none — no maximum lifetime) +POSTGRES_CONN_MAX_LIFETIME_MINS= + +# Maximum idle time before a connection is closed, in minutes (LiveRead) +# Type: int (minutes) +# Default: (none — no maximum idle time) +POSTGRES_CONN_MAX_IDLE_TIME_MINS= + +# Connection timeout in seconds (BundleRebuild) +# Type: int (seconds) +# Default: (none — driver default) +POSTGRES_CONNECT_TIMEOUT_SEC= + +# Path to the directory containing SQL migration files. 
+# ApplyBehavior: BootstrapOnly +# Type: string (directory path) +# Default: (none — migrations disabled when absent) +MIGRATIONS_PATH= + +## -- Redis (systemplane catalog, service-level connection config) -- + +# Redis server host (or comma-separated sentinel/cluster addresses) +REDIS_HOST= + +# Redis Sentinel master name (required for Sentinel topology) +REDIS_MASTER_NAME= + +# Redis authentication password (Secret: true) +REDIS_PASSWORD= + +# Redis database index (0-based) +# Type: int +# Default: 0 +REDIS_DB=0 + +# Redis protocol version (2 or 3) +# Type: int +# Default: (none — client default) +REDIS_PROTOCOL= + +# Enable TLS for Redis connection +# Type: bool +# Default: false +REDIS_TLS=false + +# PEM-encoded CA certificate for Redis TLS verification (Secret: true) +# Type: string +# Default: (none — system CA bundle used when TLS is enabled) +REDIS_CA_CERT= + +# Redis connection pool size +# Type: int +# Default: (none — client default based on CPU count) +REDIS_POOL_SIZE= + +# Minimum idle connections kept open in the pool +# Type: int +# Default: (none) +REDIS_MIN_IDLE_CONNS= + +# Read timeout in milliseconds +# Type: int (milliseconds) +# Default: (none — client default) +REDIS_READ_TIMEOUT_MS= + +# Write timeout in milliseconds +# Type: int (milliseconds) +# Default: (none — client default) +REDIS_WRITE_TIMEOUT_MS= + +# Dial (connect) timeout in milliseconds +# Type: int (milliseconds) +# Default: (none — client default) +REDIS_DIAL_TIMEOUT_MS= + # ----------------------------------------------------------------------------- # SYSTEMPLANE — commons/systemplane/bootstrap (commons/systemplane/bootstrap/env.go) # ----------------------------------------------------------------------------- @@ -403,6 +749,17 @@ SYSTEMPLANE_MONGODB_WATCH_MODE=change_stream # Default: 5 SYSTEMPLANE_MONGODB_POLL_INTERVAL_SEC=5 +## -- Secret Encryption -- + +# AES-256 master key for encrypting/decrypting secret configuration values at +# rest. 
Required when any KeyDef has Secret=true. Must be exactly 32 raw +# bytes or a base64-encoded 32-byte value. +# When unset and secret keys are declared, the secret store is silently +# skipped — downstream encrypt/decrypt calls will fail with a clear error. +# Type: string (raw 32 bytes or base64-encoded 32 bytes) +# Default: (none — required when secret keys are declared) +SYSTEMPLANE_SECRET_MASTER_KEY= + # ----------------------------------------------------------------------------- # TENANT-MANAGER CONSUMER — commons/tenant-manager/consumer # ----------------------------------------------------------------------------- diff --git a/AGENTS.md b/AGENTS.md index 105bb610..886b3b75 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -31,11 +31,14 @@ Data and messaging: - `commons/mongo`: MongoDB connector with functional options, URI builder, index helpers, OTEL spans - `commons/redis`: Redis connector with topology-based config (standalone/sentinel/cluster), GCP IAM auth, distributed locking (Redsync), backoff-based reconnect - `commons/rabbitmq`: AMQP connection/channel/health helpers with context-aware methods +- `commons/dlq`: Redis-backed dead letter queue with tenant-scoped keys, exponential backoff, and a background consumer with retry/exhaust lifecycle HTTP and server: - `commons/net/http`: Fiber HTTP helpers (response, error rendering, cursor/offset/sort pagination, validation, SSRF-protected reverse proxy, CORS, basic auth, telemetry middleware, health checks, access logging) - `commons/net/http/ratelimit`: Redis-backed rate limit storage for Fiber +- `commons/net/http/idempotency`: Redis-backed at-most-once request middleware for Fiber (SetNX, fail-open, 409 for in-flight, response replay) - `commons/server`: `ServerManager`-based graceful shutdown and lifecycle helpers +- `commons/webhook`: outbound webhook delivery with SSRF protection, HMAC-SHA256 signing, DNS pinning, concurrency control, and exponential backoff retries Resilience and safety: - 
`commons/circuitbreaker`: circuit breaker manager with preset configs and health checker @@ -45,6 +48,7 @@ Resilience and safety: - `commons/safe`: panic-free math/regex/slice operations with error returns - `commons/security`: sensitive field detection and handling - `commons/errgroup`: goroutine coordination with panic recovery +- `commons/certificate`: thread-safe TLS certificate manager with hot reload, PEM file loading, PKCS#8/PKCS#1/EC key support, and `tls.Config` integration Domain and support: - `commons/transaction`: intent-based transaction planning, balance eligibility validation, posting flow @@ -123,6 +127,16 @@ Build and shell: - Metrics via `WithMetricsFactory` option. - `NewHealthCheckerWithValidation(manager, interval, timeout, logger) (HealthChecker, error)`. +### Certificate manager (`commons/certificate`) + +- `NewManager(certPath, keyPath string) (*Manager, error)` — loads PEM files at construction; if both paths are empty an unconfigured manager is returned (TLS optional). Returns `ErrIncompleteConfig` when exactly one path is provided. +- Key parsing order: PKCS#8 → PKCS#1 (RSA) → EC (SEC 1). Key file must have mode `0600` or stricter; looser permissions return an error before reading. +- Atomic hot-reload via `(*Manager).Rotate(cert *x509.Certificate, key crypto.Signer) error` — validates expiry and public-key match before swapping under a write lock. +- Sentinel errors: `ErrNilManager`, `ErrCertRequired`, `ErrKeyRequired`, `ErrExpired`, `ErrNoPEMBlock`, `ErrKeyParseFailure`, `ErrNotSigner`, `ErrKeyMismatch`, `ErrIncompleteConfig`. +- Read accessors (all nil-safe, read-locked): `GetCertificate() *x509.Certificate`, `GetSigner() crypto.Signer`, `PublicKey() crypto.PublicKey`, `ExpiresAt() time.Time`, `DaysUntilExpiry() int`. 
+- TLS integration: `TLSCertificate() tls.Certificate` (returns populated `tls.Certificate` struct); `GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error)` — suitable for assignment to `tls.Config.GetCertificate` for transparent hot-reload. +- Package-level helper: `LoadFromFiles(certPath, keyPath string) (*x509.Certificate, crypto.Signer, error)` — validates without touching any manager state, useful for pre-flight checking before calling `Rotate`. + ### Assertions (`commons/assert`) - `New(ctx, logger, component, operation) *Asserter` and return errors instead of panicking. @@ -162,6 +176,59 @@ Build and shell: - **Redis locking:** `NewRedisLockManager(conn) (*RedisLockManager, error)` and `LockManager` interface. `LockHandle` for acquired locks. `DefaultLockOptions()`, `RateLimiterLockOptions()`. - **RabbitMQ:** `*Context()` variants of all lifecycle methods; `HealthCheck() (bool, error)`. +### Dead letter queue (`commons/dlq`) + +- `New(conn *libRedis.Client, keyPrefix string, maxRetries int, opts ...Option) *Handler` — returns nil when `conn` is nil; all Handler methods guard against a nil receiver with `ErrNilHandler`. +- Functional options for Handler: `WithLogger`, `WithTracer`, `WithMetrics`, `WithModule`. +- `DLQMetrics` interface: `RecordRetried(ctx, source)`, `RecordExhausted(ctx, source)` — nil-safe, skipped when not set. +- Key operations: `Enqueue(ctx, *FailedMessage) error` (RPush), `Dequeue(ctx, source string) (*FailedMessage, error)` (LPop, destructive), `QueueLength(ctx, source string) (int64, error)`. +- Tenant-scoped Redis keys: `":"` (e.g. `"dlq:tenant-abc:outbound"`); falls back to `""` when no tenant is in context. +- `ScanQueues(ctx, source string) ([]string, error)` — uses `SCAN` (non-blocking) to discover all tenant-scoped keys for a source; suitable for background consumers without tenant context. +- `ExtractTenantFromKey(key, source string) string` — recovers the tenant ID from a scoped Redis key. 
+- `PruneExhaustedMessages(ctx, source string, limit int) (int, error)` — dequeues up to `limit` messages, discards exhausted ones, re-enqueues the rest; at-most-once semantics. +- Backoff: exponential with AWS Full Jitter, base 30s, floor 5s, computed by `backoffDuration(retryCount)`. +- Sentinel errors: `ErrNilHandler`, `ErrNilRetryFunc`, `ErrMessageExhausted`. +- `NewConsumer(handler *Handler, retryFn RetryFunc, opts ...ConsumerOption) (*Consumer, error)` — errors if handler or retryFn is nil. +- Consumer functional options: `WithConsumerLogger`, `WithConsumerTracer`, `WithConsumerMetrics`, `WithConsumerModule`, `WithPollInterval`, `WithBatchSize`, `WithSources`. +- Consumer lifecycle: `Run(ctx)` — blocks, stops on ctx cancel or `Stop()`; `Stop()` — safe to call multiple times; `ProcessOnce(ctx)` — exported for testing. +- `FailedMessage` fields: `Source`, `OriginalData`, `ErrorMessage`, `RetryCount`, `MaxRetries`, `CreatedAt`, `NextRetryAt`, `TenantID`. + +### Idempotency middleware (`commons/net/http/idempotency`) + +- `New(conn *libRedis.Client, opts ...Option) *Middleware` — returns nil when `conn` is nil; `Check()` on a nil `*Middleware` returns a pass-through Fiber handler (fail-open by design). +- Functional options: `WithLogger`, `WithKeyPrefix` (default `"idempotency:"`), `WithKeyTTL` (default 7 days), `WithMaxKeyLength` (default 256), `WithRedisTimeout` (default 500ms), `WithRejectedHandler`, `WithMaxBodyCache` (default 1 MB). +- `(*Middleware).Check() fiber.Handler` — registers the middleware on a Fiber route. +- Only applies to mutating methods (POST, PUT, PATCH, DELETE); GET/HEAD/OPTIONS pass through unconditionally. +- Idempotency key is read from the `Idempotency-Key` request header (`constants.IdempotencyKey`); missing key passes through. +- Key too long → 400 JSON `VALIDATION_ERROR` (or custom `WithRejectedHandler`). +- Redis SetNX atomically claims the key as `"processing"` for the TTL duration. 
+- Duplicate request behavior:
+  - Cached response available → replays status + body verbatim, sets `Idempotency-Replayed: true` header.
+  - Key in `"processing"` state (in-flight) → 409 JSON `IDEMPOTENCY_CONFLICT`.
+  - Key in `"complete"` but no cached body → 200 JSON `IDEMPOTENT`.
+- On handler success: stores response body under `<key>:response` (if ≤ `maxBodyCache`), marks key as `"complete"`.
+- On handler error: deletes both keys so the client may retry with the same idempotency key.
+- Redis unavailable → fail-open (request proceeds without idempotency enforcement, logged at WARN).
+- Keys are tenant-scoped: `"<keyPrefix><tenantID>:<idempotencyKey>"`.
+
+### Webhook delivery (`commons/webhook`)
+
+- `NewDeliverer(lister EndpointLister, opts ...Option) *Deliverer` — returns nil when `lister` is nil; both `Deliver` and `DeliverWithResults` guard against a nil receiver.
+- `EndpointLister` interface: `ListActiveEndpoints(ctx context.Context) ([]Endpoint, error)`.
+- Functional options: `WithLogger`, `WithTracer`, `WithMetrics`, `WithMaxConcurrency` (default 20), `WithMaxRetries` (default 3), `WithHTTPClient`, `WithSecretDecryptor`.
+- `Deliver(ctx, *Event) error` — fans out to all active endpoints concurrently; only returns an error for pre-flight failures (nil receiver, nil event, endpoint listing failure). Per-endpoint failures are logged + metricked but do not propagate.
+- `DeliverWithResults(ctx, *Event) []DeliveryResult` — same fan-out, returns one `DeliveryResult` per active endpoint for callers that need individual outcomes.
+- `Endpoint` fields: `ID`, `URL`, `Secret` (plaintext or `enc:` prefix for encrypted), `Active`.
+- `Event` fields: `Type`, `Payload []byte`, `Timestamp int64` (Unix epoch seconds).
+- `DeliveryResult` fields: `EndpointID`, `StatusCode`, `Success`, `Error`, `Attempts`.
+- `DeliveryMetrics` interface: `RecordDelivery(ctx, endpointID string, success bool, statusCode, attempts int)`.
+- `SecretDecryptor` type: `func(encrypted string) (string, error)` — receives ciphertext with `enc:` prefix stripped. No decryptor + encrypted secret = fail-closed (delivery skipped with error). +- SSRF protection: `resolveAndValidateIP` performs a single DNS lookup, validates all resolved IPs against private/loopback/link-local/CGNAT/RFC-reserved ranges, then pins the URL to the first resolved IP — eliminates DNS rebinding TOCTOU. Only `http` and `https` schemes are allowed. +- HMAC signing: `X-Webhook-Signature: sha256=`. Timestamp is NOT included in the signature — replay protection is the receiver's responsibility. +- HTTP client blocks all redirects to prevent SSRF bypass via 302 to internal addresses. +- Retry strategy: exponential backoff with jitter (`commons/backoff`), base 1s. Non-retryable on 4xx except 429. +- Sentinel errors: `ErrNilDeliverer`, `ErrSSRFBlocked`, `ErrDeliveryFailed`, `ErrInvalidURL`. + ### Other packages - **Backoff:** `ExponentialWithJitter()` and `WaitContext()`. Used by redis and postgres for retry rate-limiting. @@ -173,6 +240,10 @@ Build and shell: - **Security:** `IsSensitiveField(name)`, `DefaultSensitiveFields()`, `DefaultSensitiveFieldsMap()`. - **Transaction:** `BuildIntentPlan()` + `ValidateBalanceEligibility()` + `ApplyPosting()` with typed `IntentPlan`, `Posting`, `LedgerTarget`. `ResolveOperation(pending, isSource, status) (Operation, error)`. - **Constants:** `SanitizeMetricLabel(value) string` for OTEL label safety. +- **Certificate:** `NewManager(certPath, keyPath) (*Manager, error)`; `Rotate(cert, key)`, `TLSCertificate()`, `GetCertificateFunc()`; package-level `LoadFromFiles(certPath, keyPath)` for pre-flight validation. +- **DLQ:** `New(conn, keyPrefix, maxRetries, opts...) *Handler`; `NewConsumer(handler, retryFn, opts...) (*Consumer, error)`; `Run(ctx)` / `Stop()` / `ProcessOnce(ctx)` for consumer lifecycle. +- **Idempotency:** `New(conn, opts...) 
*Middleware`; `(*Middleware).Check() fiber.Handler`; fail-open when Redis is unavailable. +- **Webhook:** `NewDeliverer(lister, opts...) *Deliverer`; `Deliver(ctx, event) error`; `DeliverWithResults(ctx, event) []DeliveryResult`; SSRF-protected with DNS pinning and HMAC-SHA256 signing. ## Coding rules diff --git a/README.md b/README.md index 6e9137d4..873f754a 100644 --- a/README.md +++ b/README.md @@ -47,15 +47,19 @@ go get github.com/LerianStudio/lib-commons/v4 - `commons/mongo`: `Config`-based client with functional options (`NewClient`), URI builder (`BuildURI`), `Client(ctx)`/`ResolveClient(ctx)` for access, `EnsureIndexes` (variadic), TLS support, credential clearing - `commons/redis`: topology-based `Config` (standalone/sentinel/cluster), GCP IAM auth with token refresh, distributed locking via `LockManager` interface (`NewRedisLockManager`, `LockHandle`), `SetPackageLogger` for diagnostics, TLS defaults to a TLS1.2 minimum floor with `AllowLegacyMinVersion` as an explicit temporary compatibility override - `commons/rabbitmq`: connection/channel/health helpers for AMQP with `*Context()` variants, `HealthCheck() (bool, error)`, `Close()`/`CloseContext()`, confirmable publisher with broker acks and auto-recovery, DLQ topology utilities, and health-check hardening (`AllowInsecureHealthCheck`, `HealthCheckAllowedHosts`, `RequireHealthCheckAllowedHosts`) +- `commons/dlq`: Redis-backed dead letter queue with `New(conn, keyPrefix, maxRetries, opts...)` returning nil when conn is nil (all methods guard nil receiver via `ErrNilHandler`); key operations: `Enqueue` (RPush, stamps `CreatedAt`/`MaxRetries` on first enqueue), `Dequeue` (LPop, at-most-once), `QueueLength`, `ScanQueues` (non-blocking SCAN for background consumers without tenant context), `PruneExhaustedMessages` (dequeue-discard-reenqueue cycle up to limit), `ExtractTenantFromKey`; tenant-scoped Redis keys (`":"`), backoff via exponential-with-jitter (base 30s, floor 5s, AWS Full Jitter); functional options 
`WithLogger`/`WithTracer`/`WithMetrics`/`WithModule`; `DLQMetrics` interface (`RecordRetried`/`RecordExhausted`, nil-safe); `NewConsumer(handler, retryFn, opts...) (*Consumer, error)` for background poll loop — `Run(ctx)` blocks until stop, `Stop()` idempotent, `ProcessOnce(ctx)` exported for tests; consumer options `WithConsumerLogger`/`WithConsumerTracer`/`WithConsumerMetrics`/`WithConsumerModule`/`WithPollInterval`/`WithBatchSize`/`WithSources`; sentinel errors `ErrNilHandler`, `ErrNilRetryFunc`, `ErrMessageExhausted` ### HTTP and server utilities - `commons/net/http`: Fiber HTTP helpers -- response (`Respond`/`RespondStatus`/`RespondError`/`RenderError`), health (`Ping`/`HealthWithDependencies`), SSRF-protected reverse proxy (`ServeReverseProxy` with `ReverseProxyPolicy`), pagination (offset/opaque cursor/timestamp cursor/sort cursor), validation (`ParseBodyAndValidate`/`ValidateStruct`/`ValidateSortDirection`/`ValidateLimit`), context/ownership (`ParseAndVerifyTenantScopedID`/`ParseAndVerifyResourceScopedID`), middleware (`WithHTTPLogging`/`WithGrpcLogging`/`WithCORS`/`WithBasicAuth`/`NewTelemetryMiddleware`), `FiberErrorHandler` - `commons/net/http/ratelimit`: Redis-backed distributed rate limiting middleware for Fiber — `New(conn, opts...)` returns a `*RateLimiter` (nil when disabled, nil-safe for pass-through), `WithDefaultRateLimit(conn, opts...)` as a one-liner that wires `New` + `DefaultTier` into a ready-to-use `fiber.Handler`, fixed-window counter via atomic Lua script (INCR + PEXPIRE), `WithRateLimit(tier)` for static tiers, `WithDynamicRateLimit(TierFunc)` for per-request tier selection, `MethodTierSelector` for write-vs-read split, preset tiers (`DefaultTier` / `AggressiveTier` / `RelaxedTier`) configurable via env vars, identity extractors (`IdentityFromIP` / `IdentityFromHeader` / `IdentityFromIPAndHeader` — uses `#` separator to avoid conflict with IPv6 colons), fail-open/fail-closed policy, `WithOnLimited` callback, and standard `X-RateLimit-*` 
/ `Retry-After` headers; also exports `RedisStorage` (`NewRedisStorage`) for use with third-party Fiber middleware +- `commons/net/http/idempotency`: Redis-backed at-most-once request middleware for Fiber — `New(conn, opts...) *Middleware` returns nil when conn is nil (`Check()` on nil returns pass-through handler, fail-open by design); `Check() fiber.Handler` registers on a Fiber route; applies only to mutating methods (POST/PUT/PATCH/DELETE), passes GET/HEAD/OPTIONS unconditionally; reads key from `Idempotency-Key` header (missing key passes through); key too long → 400 `VALIDATION_ERROR`; SetNX atomically claims key as `"processing"` for TTL; duplicate request outcomes: cached response available → replays status+body verbatim with `Idempotency-Replayed: true` header, key in-flight → 409 `IDEMPOTENCY_CONFLICT`, key `"complete"` but no body cache → 200 `IDEMPOTENT`; on handler success caches response under `:response` (if ≤ `maxBodyCache`) and marks key `"complete"`; on handler error deletes both keys to allow client retry with same key; Redis unavailable → fail-open logged at WARN; tenant-scoped keys `":"`; functional options `WithLogger`/`WithKeyPrefix`/`WithKeyTTL`/`WithMaxKeyLength`/`WithRedisTimeout`/`WithRejectedHandler`/`WithMaxBodyCache` +- `commons/webhook`: outbound webhook delivery with `NewDeliverer(lister, opts...) 
*Deliverer` returning nil when lister is nil (both `Deliver`/`DeliverWithResults` guard nil receiver); `Deliver(ctx, *Event) error` fans out to all active endpoints concurrently, returns errors only for pre-flight failures (nil receiver, nil event, listing failure) — per-endpoint failures are logged and metricked but do not propagate; `DeliverWithResults(ctx, *Event) []DeliveryResult` returns per-endpoint outcomes for callers needing individual results; SSRF protection via `resolveAndValidateIP`: single DNS lookup validates all resolved IPs against private/loopback/link-local/CGNAT/RFC-reserved ranges then pins URL to first resolved IP (eliminates DNS rebinding TOCTOU); redirects blocked entirely to prevent 302-to-internal bypass; HMAC-SHA256 signing via `X-Webhook-Signature: sha256=` over raw payload (timestamp not included — replay protection is the receiver's responsibility); encrypted secrets via `SecretDecryptor` func (receives ciphertext with `enc:` prefix stripped, no decryptor + encrypted secret = fail-closed); retry with exponential backoff+jitter (base 1s), non-retryable on 4xx except 429; concurrency capped by semaphore (default 20); `EndpointLister` interface (`ListActiveEndpoints`), `DeliveryMetrics` interface (`RecordDelivery`); functional options `WithLogger`/`WithTracer`/`WithMetrics`/`WithMaxConcurrency`/`WithMaxRetries`/`WithHTTPClient`/`WithSecretDecryptor`; sentinel errors `ErrNilDeliverer`/`ErrSSRFBlocked`/`ErrDeliveryFailed`/`ErrInvalidURL` - `commons/server`: `ServerManager`-based graceful shutdown with `WithHTTPServer`/`WithGRPCServer`/`WithShutdownChannel`/`WithShutdownTimeout`/`WithShutdownHook`, `StartWithGracefulShutdown()`/`StartWithGracefulShutdownWithError()`, `ServersStarted()` for test coordination ### Resilience and safety +- `commons/certificate`: thread-safe TLS certificate manager with hot reload — `NewManager(certPath, keyPath string) (*Manager, error)` loads PEM files at construction; both paths empty returns unconfigured 
manager (TLS optional), exactly one path → `ErrIncompleteConfig`; key file must have mode `0600` or stricter (checked before reading); PKCS#8 → PKCS#1 (RSA) → EC (SEC 1) key parsing order; full PEM chain parsed (all `CERTIFICATE` blocks, leaf first then intermediates); `Rotate(cert *x509.Certificate, key crypto.Signer) error` atomically hot-reloads under write lock — validates `NotBefore`/`NotAfter` temporal bounds and public-key match (`ErrKeyMismatch`) before swapping; read accessors (all nil-safe, read-locked): `GetCertificate()`/`GetSigner()`/`PublicKey()`/`ExpiresAt()`/`DaysUntilExpiry()`; TLS integration: `TLSCertificate() tls.Certificate` builds populated struct with full chain; `GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error)` for assignment to `tls.Config.GetCertificate` for transparent hot-reload; package-level `LoadFromFiles(certPath, keyPath string) (*x509.Certificate, crypto.Signer, error)` for pre-flight validation without touching manager state; sentinel errors `ErrNilManager`/`ErrCertRequired`/`ErrKeyRequired`/`ErrExpired`/`ErrNoPEMBlock`/`ErrKeyParseFailure`/`ErrNotSigner`/`ErrKeyMismatch`/`ErrIncompleteConfig` - `commons/circuitbreaker`: `Manager` interface with error-returning constructors (`NewManager`), config validation, preset configs (`DefaultConfig`/`AggressiveConfig`/`ConservativeConfig`/`HTTPServiceConfig`/`DatabaseConfig`), health checker (`NewHealthCheckerWithValidation`), metrics via `WithMetricsFactory` - `commons/backoff`: exponential backoff with jitter (`ExponentialWithJitter`) and context-aware sleep (`WaitContext`) - `commons/errgroup`: error-group concurrency with panic recovery (`WithContext`, `Go`, `Wait`), configurable logger via `SetLogger` diff --git a/commons/certificate/certificate.go b/commons/certificate/certificate.go new file mode 100644 index 00000000..6906d57c --- /dev/null +++ b/commons/certificate/certificate.go @@ -0,0 +1,390 @@ +package certificate + +import ( + "bytes" + "crypto" + 
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"time"
)

// Sentinel errors for certificate operations.
// Callers should match these with errors.Is: the constructor and loaders wrap
// them with contextual detail (e.g. "cert file: %w", "%w (notAfter: %s)").
var (
	// ErrNilManager is returned when a method is called on a nil *Manager.
	ErrNilManager = errors.New("certificate manager is nil")
	// ErrCertRequired is returned when Rotate is called with a nil certificate.
	ErrCertRequired = errors.New("certificate is required")
	// ErrKeyRequired is returned when Rotate is called with a nil private key.
	ErrKeyRequired = errors.New("private key is required")
	// ErrExpired is returned when the certificate's NotAfter time is in the past.
	ErrExpired = errors.New("certificate is expired")
	// ErrNoPEMBlock is returned when no valid PEM block can be decoded from the input.
	ErrNoPEMBlock = errors.New("no PEM block found")
	// ErrKeyParseFailure is returned when none of the supported key formats (PKCS#8, PKCS#1, EC) can parse the key bytes.
	ErrKeyParseFailure = errors.New("failed to parse private key")
	// ErrNotSigner is returned when the parsed private key does not implement crypto.Signer.
	ErrNotSigner = errors.New("private key does not implement crypto.Signer")
	// ErrKeyMismatch is returned when the certificate's public key does not match the provided private key.
	ErrKeyMismatch = errors.New("certificate public key does not match private key")
	// ErrIncompleteConfig is returned when exactly one of certPath/keyPath is provided; both or neither are required.
	ErrIncompleteConfig = errors.New("both certificate and key paths are required; got only one")
)

// Manager manages the current certificate and key with thread-safe hot reload.
// All public methods are safe for concurrent use.
type Manager struct {
	mu     sync.RWMutex
	cert   *x509.Certificate
	signer crypto.Signer
	chain  [][]byte // DER-encoded certificate chain (leaf first, then intermediates)
}

// NewManager creates a manager and loads the initial certificate from the given
// PEM file paths. If both paths are empty, an empty (unconfigured) manager is
// returned — useful for services where TLS is optional.
func NewManager(certPath, keyPath string) (*Manager, error) {
	m := &Manager{}

	// Exactly one path set is a misconfiguration (XOR check): both paths or
	// neither must be supplied.
	if (certPath != "") != (keyPath != "") {
		return nil, ErrIncompleteConfig
	}

	if certPath != "" && keyPath != "" {
		cert, signer, chain, err := loadFromFiles(certPath, keyPath)
		if err != nil {
			return nil, fmt.Errorf("load initial certificate: %w", err)
		}

		// Temporal validation mirrors Rotate: refuse to start with a
		// not-yet-valid or already-expired certificate.
		now := time.Now()
		if now.Before(cert.NotBefore) {
			return nil, fmt.Errorf("certificate is not yet valid (notBefore: %s)", cert.NotBefore)
		}

		if now.After(cert.NotAfter) {
			return nil, fmt.Errorf("%w (notAfter: %s)", ErrExpired, cert.NotAfter)
		}

		m.cert = cert
		m.signer = signer
		m.chain = chain
	}

	return m, nil
}

// Rotate replaces the current certificate and key atomically (hot reload, no restart).
// It rejects expired certificates to prevent silent deployment of invalid credentials.
// Optional intermediates are DER-encoded intermediate certificates appended after
// the leaf in the chain. When omitted, the chain contains only the leaf certificate.
// To preserve a full chain during hot reload, pass the intermediate DER bytes
// obtained from [LoadFromFilesWithChain], e.g.:
//
//	cert, signer, chain, err := LoadFromFilesWithChain(certPath, keyPath)
//	if err != nil { ... }
//	if err := m.Rotate(cert, signer, chain[1:]...); err != nil { ... }
func (m *Manager) Rotate(cert *x509.Certificate, key crypto.Signer, intermediates ...[]byte) error {
	if m == nil {
		return ErrNilManager
	}

	if cert == nil {
		return ErrCertRequired
	}

	if key == nil {
		return ErrKeyRequired
	}

	// Guard against interface wrapping a nil concrete value.
	pub := key.Public()
	if pub == nil {
		return ErrKeyRequired
	}

	now := time.Now()

	if now.Before(cert.NotBefore) {
		return fmt.Errorf("certificate is not yet valid (notBefore: %s)", cert.NotBefore)
	}

	if now.After(cert.NotAfter) {
		return fmt.Errorf("%w (notAfter: %s)", ErrExpired, cert.NotAfter)
	}

	if !publicKeysMatch(cert.PublicKey, key.Public()) {
		return ErrKeyMismatch
	}

	// Deep-copy the leaf to prevent aliasing caller-owned memory.
	// x509.ParseCertificate does NOT deep-copy the input DER, so cert.Raw
	// may alias the caller's buffer. Re-parsing from a copy ensures the
	// manager owns independent memory.
	rawCopy := make([]byte, len(cert.Raw))
	copy(rawCopy, cert.Raw)

	ownedCert, err := x509.ParseCertificate(rawCopy)
	if err != nil {
		return fmt.Errorf("certificate: failed to re-parse leaf: %w", err)
	}

	// All validation happened above, so every path below the lock succeeds:
	// the swap is all-or-nothing from a reader's perspective.
	m.mu.Lock()
	m.cert = ownedCert
	m.signer = key
	// Rebuild the owned chain: leaf first, then deep-copied intermediates.
	chain := make([][]byte, 0, 1+len(intermediates))
	chain = append(chain, rawCopy)

	for _, inter := range intermediates {
		interCopy := make([]byte, len(inter))
		copy(interCopy, inter)
		chain = append(chain, interCopy)
	}

	m.chain = chain
	m.mu.Unlock()

	return nil
}

// GetCertificate returns the current certificate, or nil if none is loaded.
// The returned pointer shares state with the manager; callers must treat it
// as read-only. Use [LoadFromFiles] + [Manager.Rotate] to replace it.
func (m *Manager) GetCertificate() *x509.Certificate {
	if m == nil {
		return nil
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	return m.cert
}

// GetSigner returns the current private key as a crypto.Signer, or nil if none is loaded.
+func (m *Manager) GetSigner() crypto.Signer { + if m == nil { + return nil + } + + m.mu.RLock() + defer m.mu.RUnlock() + + return m.signer +} + +// PublicKey returns the public key from the current certificate. +// Returns nil if no certificate is loaded or the public key cannot be extracted. +func (m *Manager) PublicKey() crypto.PublicKey { + if m == nil { + return nil + } + + m.mu.RLock() + defer m.mu.RUnlock() + + if m.cert == nil { + return nil + } + + return m.cert.PublicKey +} + +// ExpiresAt returns when the current certificate expires. +// Returns the zero time if no certificate is loaded. +func (m *Manager) ExpiresAt() time.Time { + if m == nil { + return time.Time{} + } + + m.mu.RLock() + defer m.mu.RUnlock() + + if m.cert == nil { + return time.Time{} + } + + return m.cert.NotAfter +} + +// DaysUntilExpiry returns the number of days until the certificate expires. +// Returns -1 if no certificate is loaded. +func (m *Manager) DaysUntilExpiry() int { + if m == nil { + return -1 + } + + exp := m.ExpiresAt() + if exp.IsZero() { + return -1 + } + + return int(time.Until(exp).Hours() / 24) +} + +// TLSCertificate returns a [tls.Certificate] built from the currently loaded +// certificate chain and private key. Returns an empty [tls.Certificate] if no +// certificate is loaded. The Leaf field shares state with the manager; callers +// should treat it as read-only. +func (m *Manager) TLSCertificate() tls.Certificate { + if m == nil { + return tls.Certificate{} + } + + m.mu.RLock() + defer m.mu.RUnlock() + + if m.cert == nil { + return tls.Certificate{} + } + + chainCopy := make([][]byte, len(m.chain)) + for i, der := range m.chain { + derCopy := make([]byte, len(der)) + copy(derCopy, der) + chainCopy[i] = derCopy + } + + return tls.Certificate{ + Certificate: chainCopy, + PrivateKey: m.signer, + Leaf: m.cert, + } +} + +// GetCertificateFunc returns a function suitable for use as [tls.Config.GetCertificate]. 
+// The returned function always serves the most recently loaded certificate, making +// hot-reload transparent to the TLS layer. +func (m *Manager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert := m.TLSCertificate() + if cert.Certificate == nil { + return nil, ErrCertRequired + } + + return &cert, nil + } +} + +// LoadFromFiles loads and validates a certificate and private key from PEM files +// without modifying any manager state. This allows callers to validate a new +// cert/key pair before committing the swap via [Manager.Rotate]. +// +// Private keys are parsed as PKCS#8 first, with PKCS#1 and EC key fallback. +// Returns an error if the certificate's public key does not match the private key. +func LoadFromFiles(certPath, keyPath string) (*x509.Certificate, crypto.Signer, error) { + cert, signer, _, err := loadFromFiles(certPath, keyPath) + return cert, signer, err +} + +// LoadFromFilesWithChain loads and validates a certificate and private key from +// PEM files and also returns the full DER-encoded certificate chain (leaf first, +// then intermediates). Use this when you need to pass intermediates to +// [Manager.Rotate] for chain-preserving hot reload: +// +// cert, signer, chain, err := LoadFromFilesWithChain(certPath, keyPath) +// if err != nil { ... } +// if err := m.Rotate(cert, signer, chain[1:]...); err != nil { ... } +func LoadFromFilesWithChain(certPath, keyPath string) (*x509.Certificate, crypto.Signer, [][]byte, error) { + return loadFromFiles(certPath, keyPath) +} + +// loadFromFiles is the internal implementation that also returns the full DER chain. 
+func loadFromFiles(certPath, keyPath string) (*x509.Certificate, crypto.Signer, [][]byte, error) { + certPath = filepath.Clean(certPath) + keyPath = filepath.Clean(keyPath) + + certPEM, err := os.ReadFile(certPath) // #nosec G304 -- cert path comes from trusted configuration + if err != nil { + return nil, nil, nil, fmt.Errorf("read cert: %w", err) + } + + var certChain [][]byte + + rest := certPEM + + for { + var block *pem.Block + + block, rest = pem.Decode(rest) + if block == nil { + break + } + + if block.Type == "CERTIFICATE" { + certChain = append(certChain, block.Bytes) + } + } + + if len(certChain) == 0 { + return nil, nil, nil, fmt.Errorf("cert file: %w", ErrNoPEMBlock) + } + + cert, err := x509.ParseCertificate(certChain[0]) + if err != nil { + return nil, nil, nil, fmt.Errorf("parse cert: %w", err) + } + + // Check key file permissions before reading its contents. + info, err := os.Stat(keyPath) + if err != nil { + return nil, nil, nil, fmt.Errorf("stat key file: %w", err) + } + + if perm := info.Mode().Perm(); perm&0o077 != 0 { + return nil, nil, nil, fmt.Errorf("key file %q has overly permissive mode %04o; expected 0600 or stricter", keyPath, perm) + } + + keyPEM, err := os.ReadFile(keyPath) // #nosec G304 -- key path comes from trusted configuration + if err != nil { + return nil, nil, nil, fmt.Errorf("read key: %w", err) + } + + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil { + return nil, nil, nil, fmt.Errorf("key file: %w", ErrNoPEMBlock) + } + + key, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes) + if err != nil { + // PKCS#1 fallback for legacy RSA keys. + key, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + if err != nil { + // EC key fallback for PEM-encoded SEC 1 keys. 
+ key, err = x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("%w: %w", ErrKeyParseFailure, err) + } + } + } + + signer, ok := key.(crypto.Signer) + if !ok { + return nil, nil, nil, ErrNotSigner + } + + if !publicKeysMatch(cert.PublicKey, signer.Public()) { + return nil, nil, nil, ErrKeyMismatch + } + + return cert, signer, certChain, nil +} + +// publicKeysMatch compares two public keys by their DER-encoded PKIX representation. +func publicKeysMatch(certPublicKey, signerPublicKey any) bool { + certDER, err := x509.MarshalPKIXPublicKey(certPublicKey) + if err != nil { + return false + } + + signerDER, err := x509.MarshalPKIXPublicKey(signerPublicKey) + if err != nil { + return false + } + + return bytes.Equal(certDER, signerDER) +} diff --git a/commons/certificate/certificate_test.go b/commons/certificate/certificate_test.go new file mode 100644 index 00000000..dba18d66 --- /dev/null +++ b/commons/certificate/certificate_test.go @@ -0,0 +1,879 @@ +//go:build unit + +package certificate + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func generateTestCert(t *testing.T, notAfter time.Time) (string, string) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: notAfter, + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + require.NoError(t, err) + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + certFile, err := os.Create(certPath) + require.NoError(t, 
err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + keyDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + + // Key files must be 0600 or stricter — the production code enforces this. + keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})) + require.NoError(t, keyFile.Close()) + + return certPath, keyPath +} + +func TestNewManager_Empty(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + assert.Nil(t, m.GetCertificate()) + assert.Nil(t, m.GetSigner()) + assert.Nil(t, m.PublicKey()) + assert.True(t, m.ExpiresAt().IsZero()) + assert.Equal(t, -1, m.DaysUntilExpiry()) +} + +func TestNewManager_WithCert(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + m, err := NewManager(certPath, keyPath) + require.NoError(t, err) + assert.NotNil(t, m.GetCertificate()) + assert.NotNil(t, m.GetSigner()) + assert.NotNil(t, m.PublicKey()) + assert.False(t, m.ExpiresAt().IsZero()) + assert.True(t, m.DaysUntilExpiry() > 300) +} + +func TestNewManager_OnlyOnePath(t *testing.T) { + t.Parallel() + + _, err := NewManager("cert.pem", "") + assert.ErrorIs(t, err, ErrIncompleteConfig) + + _, err = NewManager("", "key.pem") + assert.ErrorIs(t, err, ErrIncompleteConfig) +} + +func TestNewManager_InvalidCertPath(t *testing.T) { + t.Parallel() + + _, err := NewManager("/nonexistent/cert.pem", "/nonexistent/key.pem") + assert.Error(t, err) +} + +func TestNewManager_InvalidPEM(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + require.NoError(t, os.WriteFile(certPath, []byte("not pem"), 0o644)) + require.NoError(t, os.WriteFile(keyPath, []byte("not pem"), 
0o644)) + + _, err := NewManager(certPath, keyPath) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoPEMBlock) +} + +func TestNewManager_MismatchedCertificateAndKey(t *testing.T) { + t.Parallel() + + certPath, _ := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + _, keyPath := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + + _, err := NewManager(certPath, keyPath) + require.Error(t, err) + assert.ErrorIs(t, err, ErrKeyMismatch) +} + +func TestRotate_Success(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + certPath, keyPath := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + cert, signer, err := LoadFromFiles(certPath, keyPath) + require.NoError(t, err) + + assert.NoError(t, m.Rotate(cert, signer)) + assert.NotNil(t, m.GetCertificate()) +} + +func TestRotate_NilCert(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + assert.ErrorIs(t, m.Rotate(nil, nil), ErrCertRequired) +} + +func TestRotate_NilSigner(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + certPath, keyPath := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + cert, _, err := LoadFromFiles(certPath, keyPath) + require.NoError(t, err) + + assert.ErrorIs(t, m.Rotate(cert, nil), ErrKeyRequired) +} + +func TestRotate_ExpiredCert(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + certPath, keyPath := generateTestCert(t, time.Now().Add(-time.Hour)) + cert, signer, loadErr := LoadFromFiles(certPath, keyPath) + require.NoError(t, loadErr) + + rotateErr := m.Rotate(cert, signer) + assert.ErrorIs(t, rotateErr, ErrExpired) +} + +func TestRotate_NilManager(t *testing.T) { + t.Parallel() + + var m *Manager + assert.ErrorIs(t, m.Rotate(nil, nil), ErrNilManager) +} + +func TestDaysUntilExpiry_SoonExpiring(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, time.Now().Add(48*time.Hour)) + m, 
err := NewManager(certPath, keyPath) + require.NoError(t, err) + + days := m.DaysUntilExpiry() + assert.True(t, days >= 1 && days <= 2) +} + +func TestNilManager_AllMethods(t *testing.T) { + t.Parallel() + + var m *Manager + assert.Nil(t, m.GetCertificate()) + assert.Nil(t, m.GetSigner()) + assert.Nil(t, m.PublicKey()) + assert.True(t, m.ExpiresAt().IsZero()) + assert.Equal(t, -1, m.DaysUntilExpiry()) +} + +// --------------------------------------------------------------------------- +// 1. TestLoadFromFiles_DirectCoverage +// --------------------------------------------------------------------------- + +func TestLoadFromFiles_DirectCoverage(t *testing.T) { + t.Parallel() + + t.Run("happy path", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + cert, signer, err := LoadFromFiles(certPath, keyPath) + require.NoError(t, err) + assert.NotNil(t, cert) + assert.NotNil(t, signer) + }) + + t.Run("nonexistent cert file", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + _, _, err := LoadFromFiles( + filepath.Join(dir, "missing-cert.pem"), + filepath.Join(dir, "missing-key.pem"), + ) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "read cert"), "expected 'read cert' in error, got: %s", err) + }) + + t.Run("nonexistent key file", func(t *testing.T) { + t.Parallel() + + // Create a valid cert file but no key file. 
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(42), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + require.NoError(t, err) + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + _, _, loadErr := LoadFromFiles(certPath, filepath.Join(dir, "missing-key.pem")) + require.Error(t, loadErr) + // stat key file fails before read, so the error says "stat key file" + assert.True(t, strings.Contains(loadErr.Error(), "key"), "expected key-related error, got: %s", loadErr) + }) + + t.Run("invalid PEM cert", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + require.NoError(t, os.WriteFile(certPath, []byte("not pem data at all"), 0o644)) + require.NoError(t, os.WriteFile(keyPath, []byte("not pem data at all"), 0o600)) + + _, _, err := LoadFromFiles(certPath, keyPath) + require.Error(t, err) + assert.ErrorIs(t, err, ErrNoPEMBlock) + }) + + t.Run("invalid PEM key", func(t *testing.T) { + t.Parallel() + + // Valid cert, but key file is not PEM. 
+ certPath, _ := generateTestCert(t, time.Now().Add(time.Hour)) + + dir := t.TempDir() + keyPath := filepath.Join(dir, "key.pem") + require.NoError(t, os.WriteFile(keyPath, []byte("this is not pem"), 0o600)) + + _, _, err := LoadFromFiles(certPath, keyPath) + require.Error(t, err) + assert.ErrorIs(t, err, ErrNoPEMBlock) + }) + + t.Run("garbage DER in cert PEM", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + // A syntactically valid PEM block, but the DER bytes inside are garbage. + garbage := []byte("definitely not ASN.1") + var certBuf strings.Builder + require.NoError(t, pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: garbage})) + require.NoError(t, os.WriteFile(certPath, []byte(certBuf.String()), 0o644)) + require.NoError(t, os.WriteFile(keyPath, []byte("placeholder"), 0o600)) + + _, _, err := LoadFromFiles(certPath, keyPath) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "parse cert"), "expected 'parse cert' in error, got: %s", err) + }) + + t.Run("garbage DER in key PEM", func(t *testing.T) { + t.Parallel() + + // Use a valid cert file, then a PEM key file with garbage DER bytes. 
+ certPath, _ := generateTestCert(t, time.Now().Add(time.Hour)) + + dir := t.TempDir() + keyPath := filepath.Join(dir, "key.pem") + + var keyBuf strings.Builder + require.NoError(t, pem.Encode(&keyBuf, &pem.Block{Type: "PRIVATE KEY", Bytes: []byte("garbage der")})) + require.NoError(t, os.WriteFile(keyPath, []byte(keyBuf.String()), 0o600)) + + _, _, err := LoadFromFiles(certPath, keyPath) + require.Error(t, err) + assert.ErrorIs(t, err, ErrKeyParseFailure) + }) + + t.Run("mismatched cert and key", func(t *testing.T) { + t.Parallel() + + certPath, _ := generateTestCert(t, time.Now().Add(time.Hour)) + _, keyPath := generateTestCert(t, time.Now().Add(time.Hour)) + + _, _, err := LoadFromFiles(certPath, keyPath) + require.Error(t, err) + assert.ErrorIs(t, err, ErrKeyMismatch) + }) + + t.Run("PKCS1 RSA key", func(t *testing.T) { + t.Parallel() + + // Generate an RSA key and encode it as PKCS#1 ("RSA PRIVATE KEY"). + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(99), + Subject: pkix.Name{CommonName: "rsa-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &rsaKey.PublicKey, rsaKey) + require.NoError(t, err) + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + // PKCS#1 encoding — the legacy "RSA PRIVATE KEY" PEM type. 
+ pkcs1DER := x509.MarshalPKCS1PrivateKey(rsaKey) + + keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkcs1DER})) + require.NoError(t, keyFile.Close()) + + cert, signer, loadErr := LoadFromFiles(certPath, keyPath) + require.NoError(t, loadErr) + assert.NotNil(t, cert) + assert.NotNil(t, signer) + }) +} + +// --------------------------------------------------------------------------- +// TestLoadFromFiles_ValidCertInvalidKeyPEM — L-CERT-3 +// Covers the specific combination: cert file is valid PEM, key file exists but +// contains non-PEM data (not merely a missing/unreadable file). +// --------------------------------------------------------------------------- + +func TestLoadFromFiles_ValidCertInvalidKeyPEM(t *testing.T) { + t.Parallel() + + // Generate a valid cert; discard its key — we will supply a broken key file. + certPath, _ := generateTestCert(t, time.Now().Add(time.Hour)) + + dir := t.TempDir() + keyPath := filepath.Join(dir, "key.pem") + + // Key file exists and is readable, but contains no PEM block at all. + require.NoError(t, os.WriteFile(keyPath, []byte("this is not pem data"), 0o600)) + + _, _, err := LoadFromFiles(certPath, keyPath) + require.Error(t, err) + assert.ErrorIs(t, err, ErrNoPEMBlock, + "valid cert + non-PEM key must return ErrNoPEMBlock, got: %s", err) +} + +// --------------------------------------------------------------------------- +// 2. TestConcurrentRotateAndRead +// --------------------------------------------------------------------------- + +func TestConcurrentRotateAndRead(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + // Pre-load a cert so readers have something to work with from the start. 
+ certPath, keyPath := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + initialCert, initialSigner, err := LoadFromFiles(certPath, keyPath) + require.NoError(t, err) + require.NoError(t, m.Rotate(initialCert, initialSigner)) + + const goroutines = 10 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := range goroutines { + go func(i int) { + defer wg.Done() + + if i%2 == 0 { + // Writers: rotate with a fresh cert. + cp, kp := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + c, s, loadErr := LoadFromFiles(cp, kp) + if loadErr != nil { + return + } + // Ignore the error — another goroutine may have already rotated. + _ = m.Rotate(c, s) + } else { + // Readers: call all read methods. + _ = m.GetCertificate() + _ = m.GetSigner() + _ = m.PublicKey() + _ = m.ExpiresAt() + _ = m.DaysUntilExpiry() + _ = m.TLSCertificate() + } + }(i) + } + + wg.Wait() + + // After all goroutines finish the manager must still be in a valid state. + assert.NotNil(t, m.GetCertificate()) +} + +// --------------------------------------------------------------------------- +// 3. TestRotate_AtomicityOnError +// --------------------------------------------------------------------------- + +func TestRotate_AtomicityOnError(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + // State before: everything nil. + require.Nil(t, m.GetCertificate()) + require.Nil(t, m.GetSigner()) + + // Attempt to rotate with an expired cert (should fail). + certPath, keyPath := generateTestCert(t, time.Now().Add(-time.Hour)) + expiredCert, expiredSigner, loadErr := LoadFromFiles(certPath, keyPath) + require.NoError(t, loadErr) + + rotateErr := m.Rotate(expiredCert, expiredSigner) + require.ErrorIs(t, rotateErr, ErrExpired) + + // State must be unchanged — still nil. 
+ assert.Nil(t, m.GetCertificate(), "GetCertificate must remain nil after failed Rotate") + assert.Nil(t, m.GetSigner(), "GetSigner must remain nil after failed Rotate") +} + +// --------------------------------------------------------------------------- +// 4. TestNewManager_WithCert_StrongAssertions +// --------------------------------------------------------------------------- + +func TestNewManager_WithCert_StrongAssertions(t *testing.T) { + t.Parallel() + + notAfter := time.Now().Add(365 * 24 * time.Hour) + certPath, keyPath := generateTestCert(t, notAfter) + + m, err := NewManager(certPath, keyPath) + require.NoError(t, err) + + cert := m.GetCertificate() + require.NotNil(t, cert) + assert.Equal(t, "test", cert.Subject.CommonName) + + // ExpiresAt should be close to the notAfter we passed in. + expiresAt := m.ExpiresAt() + assert.False(t, expiresAt.IsZero()) + assert.WithinDuration(t, notAfter, expiresAt, 2*time.Second) + + // PublicKey from manager must match the signer's public key. + signer := m.GetSigner() + require.NotNil(t, signer) + + managerPub := m.PublicKey() + require.NotNil(t, managerPub) + assert.True(t, publicKeysMatch(managerPub, signer.Public()), + "PublicKey() must match GetSigner().Public()") +} + +// --------------------------------------------------------------------------- +// 5. TestRotate_Success_StrongAssertions +// --------------------------------------------------------------------------- + +func TestRotate_Success_StrongAssertions(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + certPath, keyPath := generateTestCert(t, time.Now().Add(365*24*time.Hour)) + newCert, newSigner, err := LoadFromFiles(certPath, keyPath) + require.NoError(t, err) + + require.NoError(t, m.Rotate(newCert, newSigner)) + + // The certificate stored in the manager must be the exact one we rotated in + // (same SerialNumber, since generateTestCert always uses serial 1). 
+ got := m.GetCertificate() + require.NotNil(t, got) + assert.Equal(t, newCert.SerialNumber, got.SerialNumber, + "GetCertificate().SerialNumber must match the rotated cert") + + // Signer and PublicKey must also be updated. + gotSigner := m.GetSigner() + require.NotNil(t, gotSigner) + assert.True(t, publicKeysMatch(newSigner.Public(), gotSigner.Public()), + "GetSigner().Public() must match the rotated signer") + + gotPub := m.PublicKey() + require.NotNil(t, gotPub) + assert.True(t, publicKeysMatch(newCert.PublicKey, gotPub), + "PublicKey() must match the rotated cert's public key") +} + +// --------------------------------------------------------------------------- +// 6. TestLoadFromFiles_CertificateChain +// --------------------------------------------------------------------------- + +func TestLoadFromFiles_CertificateChain(t *testing.T) { + t.Parallel() + + // Generate a "leaf" key and cert. + leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + leafTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "leaf"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + } + + leafDER, err := x509.CreateCertificate(rand.Reader, leafTmpl, leafTmpl, &leafKey.PublicKey, leafKey) + require.NoError(t, err) + + // Generate a separate "intermediate" cert (self-signed, just for chain testing). 
+ intermediateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + intermediateTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "intermediate"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + } + + intermediateDER, err := x509.CreateCertificate(rand.Reader, intermediateTmpl, intermediateTmpl, &intermediateKey.PublicKey, intermediateKey) + require.NoError(t, err) + + dir := t.TempDir() + certPath := filepath.Join(dir, "chain.pem") + keyPath := filepath.Join(dir, "key.pem") + + // Write both certs into a single PEM file: leaf first, then intermediate. + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: leafDER})) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: intermediateDER})) + require.NoError(t, certFile.Close()) + + // Write the leaf key. + leafKeyDER, err := x509.MarshalPKCS8PrivateKey(leafKey) + require.NoError(t, err) + + keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: leafKeyDER})) + require.NoError(t, keyFile.Close()) + + // LoadFromFiles should succeed and return the leaf cert. + cert, signer, err := LoadFromFiles(certPath, keyPath) + require.NoError(t, err) + assert.Equal(t, "leaf", cert.Subject.CommonName) + assert.NotNil(t, signer) + + // Load via NewManager so we can inspect TLSCertificate() chain. + m, err := NewManager(certPath, keyPath) + require.NoError(t, err) + + tlsCert := m.TLSCertificate() + // The chain should contain both DER blocks: leaf + intermediate. 
+ assert.Len(t, tlsCert.Certificate, 2, "TLSCertificate chain must contain leaf + intermediate") + assert.Equal(t, leafDER, tlsCert.Certificate[0], "first chain entry must be the leaf DER") + assert.Equal(t, intermediateDER, tlsCert.Certificate[1], "second chain entry must be the intermediate DER") +} + +// --------------------------------------------------------------------------- +// 7. TestTLSCertificate_And_GetCertificateFunc +// --------------------------------------------------------------------------- + +func TestTLSCertificate_And_GetCertificateFunc(t *testing.T) { + t.Parallel() + + t.Run("empty manager returns empty TLSCertificate", func(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + tlsCert := m.TLSCertificate() + assert.Nil(t, tlsCert.Certificate, "empty manager must return zero tls.Certificate") + assert.Nil(t, tlsCert.Leaf) + }) + + t.Run("loaded manager returns valid tls.Certificate", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, time.Now().Add(24*time.Hour)) + m, err := NewManager(certPath, keyPath) + require.NoError(t, err) + + tlsCert := m.TLSCertificate() + require.NotNil(t, tlsCert.Certificate, "TLSCertificate must have DER chain") + require.NotNil(t, tlsCert.PrivateKey, "TLSCertificate must have PrivateKey") + require.NotNil(t, tlsCert.Leaf, "TLSCertificate must have Leaf") + + assert.Equal(t, "test", tlsCert.Leaf.Subject.CommonName) + + // The DER bytes in the chain must decode to the same cert. 
+ parsed, parseErr := x509.ParseCertificate(tlsCert.Certificate[0]) + require.NoError(t, parseErr) + assert.Equal(t, tlsCert.Leaf.SerialNumber, parsed.SerialNumber) + }) + + t.Run("GetCertificateFunc returns cert for loaded manager", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, time.Now().Add(24*time.Hour)) + m, err := NewManager(certPath, keyPath) + require.NoError(t, err) + + fn := m.GetCertificateFunc() + require.NotNil(t, fn) + + tlsCert, callErr := fn(&tls.ClientHelloInfo{}) + require.NoError(t, callErr) + require.NotNil(t, tlsCert) + assert.NotNil(t, tlsCert.Leaf) + }) + + t.Run("GetCertificateFunc returns error when no cert loaded", func(t *testing.T) { + t.Parallel() + + m, err := NewManager("", "") + require.NoError(t, err) + + fn := m.GetCertificateFunc() + require.NotNil(t, fn) + + _, callErr := fn(&tls.ClientHelloInfo{}) + require.Error(t, callErr) + assert.ErrorIs(t, callErr, ErrCertRequired) + }) +} + +// --------------------------------------------------------------------------- +// 8. TestLoadFromFiles_FilePermissions +// --------------------------------------------------------------------------- + +func TestLoadFromFiles_FilePermissions(t *testing.T) { + t.Parallel() + + // We need a valid cert file; the cert itself is readable at any permission. + certPath, _ := generateTestCert(t, time.Now().Add(time.Hour)) + + // Build a valid key DER separately so we can write it to our own key file. + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + // Re-create the cert so it matches this key. 
+ tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(7), + Subject: pkix.Name{CommonName: "perm-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + require.NoError(t, err) + + dir := t.TempDir() + certPath = filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + keyDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + + t.Run("0644 key file is rejected", func(t *testing.T) { + // Write key file with overly permissive 0644 mode. + keyFile, createErr := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + require.NoError(t, createErr) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})) + require.NoError(t, keyFile.Close()) + + _, _, loadErr := LoadFromFiles(certPath, keyPath) + require.Error(t, loadErr) + assert.True(t, strings.Contains(loadErr.Error(), "permissive"), + "error must mention permissive mode, got: %s", loadErr) + }) + + t.Run("0600 key file is accepted", func(t *testing.T) { + // Tighten permissions to 0600. + require.NoError(t, os.Chmod(keyPath, 0o600)) + + cert, signer, loadErr := LoadFromFiles(certPath, keyPath) + require.NoError(t, loadErr) + assert.NotNil(t, cert) + assert.NotNil(t, signer) + }) +} + +// --------------------------------------------------------------------------- +// 9. 
TestNewManager_InvalidCertPath_ErrorSpecificity +// --------------------------------------------------------------------------- + +func TestNewManager_InvalidCertPath_ErrorSpecificity(t *testing.T) { + t.Parallel() + + t.Run("cert file not found", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + _, err := NewManager( + filepath.Join(dir, "no-cert.pem"), + filepath.Join(dir, "no-key.pem"), + ) + require.Error(t, err) + // The error must reference cert reading, not key reading. + assert.True(t, strings.Contains(err.Error(), "cert"), + "error must mention cert, got: %s", err) + }) + + t.Run("key file not found distinguishable from cert not found", func(t *testing.T) { + t.Parallel() + + // Create a valid cert file but no key file. + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(5), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + require.NoError(t, err) + + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + _, certErrOnly := NewManager( + filepath.Join(dir, "no-cert.pem"), + filepath.Join(dir, "no-key.pem"), + ) + require.Error(t, certErrOnly) + + _, keyErrOnly := NewManager(certPath, filepath.Join(dir, "no-key.pem")) + require.Error(t, keyErrOnly) + + // The two errors must not be identical — they should reference different things. 
+	assert.NotEqual(t, certErrOnly.Error(), keyErrOnly.Error(),
+		"missing-cert and missing-key errors should be distinguishable")
+
+		assert.False(t,
+			strings.Contains(keyErrOnly.Error(), "read cert"),
+			"key-missing error must not blame cert reading, got: %s", keyErrOnly)
+		assert.True(t,
+			strings.Contains(keyErrOnly.Error(), "key"),
+			"key-missing error must mention key, got: %s", keyErrOnly)
+	})
+}
+
+// ---------------------------------------------------------------------------
+// 10. TestNilManager_SubTests
+// ---------------------------------------------------------------------------
+
+func TestNilManager_SubTests(t *testing.T) {
+	t.Parallel()
+
+	var m *Manager
+
+	t.Run("GetCertificate returns nil", func(t *testing.T) {
+		t.Parallel()
+		assert.Nil(t, m.GetCertificate())
+	})
+
+	t.Run("GetSigner returns nil", func(t *testing.T) {
+		t.Parallel()
+		assert.Nil(t, m.GetSigner())
+	})
+
+	t.Run("PublicKey returns nil", func(t *testing.T) {
+		t.Parallel()
+		assert.Nil(t, m.PublicKey())
+	})
+
+	t.Run("ExpiresAt returns zero time", func(t *testing.T) {
+		t.Parallel()
+		assert.True(t, m.ExpiresAt().IsZero())
+	})
+
+	t.Run("DaysUntilExpiry returns -1", func(t *testing.T) {
+		t.Parallel()
+		assert.Equal(t, -1, m.DaysUntilExpiry())
+	})
+
+	t.Run("TLSCertificate returns empty", func(t *testing.T) {
+		t.Parallel()
+		// TLSCertificate is the one accessor that is NOT nil-receiver-safe:
+		// its `if m == nil` guard only runs after mu.RLock(), and calling
+		// RLock through a nil receiver panics before the guard is reached.
+		// (The other accessors check the receiver before touching the lock,
+		// so they tolerate a nil *Manager.)
+		//
+		// Rather than asserting that panic here, we pin the documented
+		// zero-value contract on a non-nil, empty manager: TLSCertificate
+		// must return an empty tls.Certificate when nothing is loaded.
+		empty, emptyErr := NewManager("", "")
+		require.NoError(t, emptyErr)
+		tlsCert := empty.TLSCertificate()
+		assert.Equal(t, tls.Certificate{}, tlsCert)
+	})
+}
diff --git a/commons/certificate/doc.go b/commons/certificate/doc.go
new file mode 100644
index 00000000..f151973b
--- /dev/null
+++ b/commons/certificate/doc.go
@@ -0,0 +1,35 @@
+// Package certificate provides a thread-safe TLS certificate manager with hot reload.
+//
+// The [Manager] loads X.509 certificates and private keys from PEM files, supports
+// zero-downtime rotation via [Manager.Rotate], and provides concurrent read access
+// through an internal sync.RWMutex.
+//
+// # Quick start
+//
+//	m, err := certificate.NewManager("server.crt", "server.key")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//
+//	// Use in TLS config
+//	cert := m.GetCertificate()
+//	signer := m.GetSigner()
+//
+//	// Hot-reload without restart
+//	newCert, newKey, err := certificate.LoadFromFiles("new.crt", "new.key")
+//	if err != nil {
+//		log.Printf("pre-flight validation failed: %v", err)
+//	} else if err := m.Rotate(newCert, newKey); err != nil {
+//		log.Printf("certificate rotation failed: %v", err)
+//	}
+//
+// # Key formats
+//
+// Private keys are parsed in order: PKCS#8 first, then PKCS#1 (RSA) fallback,
+// then EC (SEC 1) fallback. The manager validates that the certificate's public
+// key matches the private key at load time to prevent silent misconfiguration.
+//
+// # Nil safety
+//
+// Accessor methods on a nil *Manager return zero values without panicking, with one exception: TLSCertificate requires a non-nil receiver (its nil guard runs after the lock).
+package certificate diff --git a/commons/dlq/consumer.go b/commons/dlq/consumer.go new file mode 100644 index 00000000..e3c21000 --- /dev/null +++ b/commons/dlq/consumer.go @@ -0,0 +1,514 @@ +package dlq + +import ( + "context" + "errors" + "fmt" + "slices" + "sync" + "time" + + libLog "github.com/LerianStudio/lib-commons/v4/commons/log" + libOtel "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry" + libRuntime "github.com/LerianStudio/lib-commons/v4/commons/runtime" + tmcore "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core" + "github.com/redis/go-redis/v9" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +// RetryFunc is invoked by the Consumer for each retryable message. Return nil +// on success (message is discarded), or an error to re-enqueue with incremented +// retry count and updated backoff. +type RetryFunc func(ctx context.Context, msg *FailedMessage) error + +// ConsumerConfig holds tuning knobs for the background consumer. +type ConsumerConfig struct { + // PollInterval is how often the consumer checks for retryable messages. + // Default: 30s. + PollInterval time.Duration + // BatchSize is the max messages processed per source per poll cycle. + // Default: 10. + BatchSize int + // Sources lists the DLQ source queue names to consume (e.g. "outbound", "inbound"). + Sources []string +} + +// Consumer polls DLQ queues and retries messages via a caller-provided RetryFunc. +// Messages that succeed are discarded; messages that fail again are re-enqueued +// with incremented retry count and exponential backoff. Messages that exceed +// MaxRetries are logged as permanently failed and discarded. +type Consumer struct { + handler *Handler + retryFunc RetryFunc + logger libLog.Logger + tracer trace.Tracer + metrics DLQMetrics + module string + cfg ConsumerConfig + + stopMu sync.Mutex + stopCh chan struct{} +} + +// ConsumerOption configures a Consumer at construction time. 
+type ConsumerOption func(*Consumer) + +// WithConsumerLogger sets the logger used by the Consumer. +func WithConsumerLogger(l libLog.Logger) ConsumerOption { + return func(c *Consumer) { + if l != nil { + c.logger = l + } + } +} + +// WithConsumerTracer sets the OpenTelemetry tracer used by the Consumer. +func WithConsumerTracer(t trace.Tracer) ConsumerOption { + return func(c *Consumer) { + if t != nil { + c.tracer = t + } + } +} + +// WithConsumerMetrics sets the metrics recorder used by the Consumer. +func WithConsumerMetrics(m DLQMetrics) ConsumerOption { + return func(c *Consumer) { + c.metrics = m + } +} + +// WithConsumerModule sets a module label for log and metric context. +func WithConsumerModule(module string) ConsumerOption { + return func(c *Consumer) { + if module != "" { + c.module = module + } + } +} + +// WithPollInterval sets the consumer poll interval. +func WithPollInterval(d time.Duration) ConsumerOption { + return func(c *Consumer) { + if d > 0 { + c.cfg.PollInterval = d + } + } +} + +// WithBatchSize sets the maximum messages processed per source per poll. +func WithBatchSize(n int) ConsumerOption { + return func(c *Consumer) { + if n > 0 { + c.cfg.BatchSize = n + } + } +} + +// WithSources sets the DLQ source queue names to consume. +func WithSources(sources ...string) ConsumerOption { + return func(c *Consumer) { + if len(sources) > 0 { + c.cfg.Sources = sources + } + } +} + +// NewConsumer creates a DLQ consumer. handler and retryFn are required; returns +// an error if either is nil. 
+func NewConsumer(handler *Handler, retryFn RetryFunc, opts ...ConsumerOption) (*Consumer, error) { + if handler == nil { + return nil, ErrNilHandler + } + + if retryFn == nil { + return nil, ErrNilRetryFunc + } + + c := &Consumer{ + handler: handler, + retryFunc: retryFn, + logger: libLog.NewNop(), + tracer: noop.NewTracerProvider().Tracer("dlq.consumer.noop"), + cfg: ConsumerConfig{ + PollInterval: 30 * time.Second, + BatchSize: 10, + }, + } + + for _, opt := range opts { + if opt != nil { + opt(c) + } + } + + // Inherit handler settings when not overridden via options. + if c.metrics == nil { + c.metrics = handler.metrics + } + + if c.module == "" { + c.module = handler.module + } + + return c, nil +} + +// Run starts the consumer loop, polling on the configured interval until ctx is +// cancelled or Stop is called. Run blocks until shutdown. +func (c *Consumer) Run(ctx context.Context) { + if c == nil { + return + } + + defer libRuntime.RecoverWithPolicyAndContext(ctx, c.logger, c.module, "dlq-consumer-loop", libRuntime.KeepRunning) + + c.stopMu.Lock() + if c.stopCh != nil { + // Already running or previous goroutine still draining — reject to + // prevent overlapping Run loops. The deferred cleanup in the active + // goroutine will nil c.stopCh once it fully exits. 
+ c.stopMu.Unlock() + + c.logger.Log(ctx, libLog.LevelWarn, "dlq consumer: Run() called while already running, ignoring") + + return + } + + runStopCh := make(chan struct{}) + c.stopCh = runStopCh + c.stopMu.Unlock() + + defer func() { + c.stopMu.Lock() + if c.stopCh == runStopCh { + c.stopCh = nil + } + c.stopMu.Unlock() + }() + + ticker := time.NewTicker(c.cfg.PollInterval) + defer ticker.Stop() + + c.logger.Log(ctx, libLog.LevelInfo, "dlq consumer started", + libLog.String("sources", fmt.Sprintf("%v", c.cfg.Sources)), + libLog.String("interval", c.cfg.PollInterval.String()), + libLog.Int("batch_size", c.cfg.BatchSize), + ) + + for { + select { + case <-ctx.Done(): + c.logger.Log(ctx, libLog.LevelInfo, "dlq consumer stopped") + + return + case <-runStopCh: + c.logger.Log(ctx, libLog.LevelInfo, "dlq consumer stopped") + + return + case <-ticker.C: + c.safeProcessOnce(ctx) + } + } +} + +// Stop signals the consumer loop to exit. Safe to call multiple times. +func (c *Consumer) Stop() { + if c == nil { + return + } + + c.stopMu.Lock() + defer c.stopMu.Unlock() + + if c.stopCh != nil { + select { + case <-c.stopCh: + // Already closed. + default: + close(c.stopCh) + } + } +} + +// ProcessOnce executes a single poll cycle across all configured sources. +// Exported for testing; in production, use Run. +func (c *Consumer) ProcessOnce(ctx context.Context) { + if c == nil { + return + } + + c.processOnce(ctx) +} + +// safeProcessOnce wraps processOnce with panic recovery. +func (c *Consumer) safeProcessOnce(ctx context.Context) { + defer libRuntime.RecoverWithPolicyAndContext(ctx, c.logger, c.module, "dlq-poll-cycle", libRuntime.KeepRunning) + + c.processOnce(ctx) +} + +// processOnce iterates over each configured source and processes up to BatchSize +// messages per source. 
+func (c *Consumer) processOnce(ctx context.Context) { + for _, source := range c.cfg.Sources { + c.processSource(ctx, source) + } +} + +// processSource handles a single DLQ source: discovers tenant-scoped keys via +// Redis SCAN, then round-robin drains messages from each key. +func (c *Consumer) processSource(ctx context.Context, source string) { + ctx, span := c.tracer.Start(ctx, "dlq.consumer.process_source") + defer span.End() + + // Discover tenant-scoped keys (e.g. "dlq:tenant-A:outbound"). + tenantKeys, err := c.handler.ScanQueues(ctx, source) + if err != nil { + c.logger.Log(ctx, libLog.LevelWarn, "dlq consumer: tenant key scan failed", + libLog.String("source", source), + libLog.Err(err), + ) + + return + } + + // Build a context per discovered tenant. Include the bare (non-tenant) context + // only when no tenant keys were found — if tenant keys exist, draining the + // global (non-tenant) key too would double-process the same logical queue. + keyContexts := []context.Context{} + + for _, key := range tenantKeys { + tenantID := c.handler.ExtractTenantFromKey(key, source) + if tenantID == "" { + continue + } + + tenantCtx := tmcore.ContextWithTenantID(ctx, tenantID) + keyContexts = append(keyContexts, tenantCtx) + } + + if len(keyContexts) == 0 { + keyContexts = append(keyContexts, ctx) + } + + // Round-robin drain across all discovered keys up to BatchSize total. + processed := 0 + for processed < c.cfg.BatchSize { + progressed := false + + for _, keyCtx := range keyContexts { + if processed >= c.cfg.BatchSize { + break + } + + if keyCtx.Err() != nil { + return + } + + if c.drainSource(keyCtx, source, 1) > 0 { + processed++ + progressed = true + } + } + + if !progressed { + return + } + } +} + +// drainSource dequeues up to limit messages from a single source key and +// processes each one. Returns the count of messages processed. 
+func (c *Consumer) drainSource(ctx context.Context, source string, limit int) int { + processed := 0 + + for range limit { + select { + case <-ctx.Done(): + return processed + default: + } + + msg, err := c.handler.Dequeue(ctx, source) + if err != nil { + if isRedisNilError(err) { + return processed + } + + c.logger.Log(ctx, libLog.LevelWarn, "dlq consumer: dequeue failed", + libLog.String("source", source), + libLog.Err(err), + ) + + return processed + } + + // Defensive nil guard: Dequeue cannot structurally return (nil, nil) + // today, but this check protects processMessage from panicking if the + // invariant is ever broken by refactoring. + if msg == nil { + return processed + } + + if c.processMessage(ctx, msg) { + processed++ + } + } + + return processed +} + +// processMessage handles a single dequeued message: validates timing, attempts +// retry, and decides whether to discard or re-enqueue. Returns true when actual +// work was performed (retry attempted or message exhausted), false when the +// message was merely bounced back because it is not yet ready for retry. +func (c *Consumer) processMessage(ctx context.Context, msg *FailedMessage) bool { + ctx, span := c.tracer.Start(ctx, "dlq.consumer.process_message") + defer span.End() + + // Restore tenant context from the persisted TenantID. Prefer the message's + // tenant over the queue context to prevent cross-tenant retries after + // dequeuing from a legacy or corrupted key. 
+ if msg.TenantID != "" { + ctxTenant := tmcore.GetTenantIDContext(ctx) + if ctxTenant != msg.TenantID { + if ctxTenant != "" { + c.logger.Log(ctx, libLog.LevelWarn, "dlq consumer: tenant mismatch, restoring message tenant", + libLog.String("queue_tenant", ctxTenant), + libLog.String("message_tenant", msg.TenantID), + libLog.String("source", msg.Source), + ) + } + + ctx = tmcore.ContextWithTenantID(ctx, msg.TenantID) + } + } + + now := time.Now().UTC() + + // Not yet time to retry — re-enqueue at the back so other messages proceed. + if !msg.NextRetryAt.IsZero() && now.Before(msg.NextRetryAt) { + if err := c.handler.Enqueue(ctx, msg); err != nil { + libOtel.HandleSpanError(span, "dlq message lost on re-enqueue", err) + + c.logger.Log(ctx, libLog.LevelError, "dlq consumer: message lost — failed to re-enqueue not-yet-ready message", + libLog.String("source", msg.Source), + libLog.Int("retry_count", msg.RetryCount), + libLog.Err(err), + ) + + metricSource := c.sanitizeMetricSource(msg.Source) + if c.metrics != nil { + c.metrics.RecordLost(ctx, metricSource) + } + + return true + } + + return false + } + + // Message exhausted — permanently failed, discard. + if msg.RetryCount >= msg.MaxRetries { + libOtel.HandleSpanError(span, "dlq message exhausted", ErrMessageExhausted) + + metricSource := c.sanitizeMetricSource(msg.Source) + if c.metrics != nil { + c.metrics.RecordExhausted(ctx, metricSource) + } + + c.logger.Log(ctx, libLog.LevelError, "dlq consumer: message permanently failed, discarding", + libLog.String("source", msg.Source), + libLog.Int("retry_count", msg.RetryCount), + libLog.Int("max_retries", msg.MaxRetries), + libLog.String("last_error", truncateString(msg.ErrorMessage, 200)), + ) + + return true + } + + // Attempt retry with panic recovery to prevent message loss. + retryErr := c.safeRetryFunc(ctx, msg) + if retryErr != nil { + // Retry failed — increment count, recalculate backoff, re-enqueue. 
+ msg.RetryCount++ + msg.ErrorMessage = retryErr.Error() + msg.NextRetryAt = time.Now().UTC().Add(backoffDuration(msg.RetryCount)) + + if requeueErr := c.handler.Enqueue(ctx, msg); requeueErr != nil { + libOtel.HandleSpanError(span, "dlq message lost on re-enqueue", requeueErr) + + c.logger.Log(ctx, libLog.LevelError, "dlq consumer: message lost — failed to re-enqueue after retry failure", + libLog.String("source", msg.Source), + libLog.Int("retry_count", msg.RetryCount), + libLog.Err(requeueErr), + ) + + metricSource := c.sanitizeMetricSource(msg.Source) + if c.metrics != nil { + c.metrics.RecordLost(ctx, metricSource) + } + + return true + } + + c.logger.Log(ctx, libLog.LevelWarn, "dlq consumer: retry failed, re-enqueued", + libLog.String("source", msg.Source), + libLog.Int("retry_count", msg.RetryCount), + libLog.Err(retryErr), + ) + + return true + } + + // Retry succeeded — record metric and discard. + metricSource := c.sanitizeMetricSource(msg.Source) + if c.metrics != nil { + c.metrics.RecordRetried(ctx, metricSource) + } + + c.logger.Log(ctx, libLog.LevelInfo, "dlq consumer: message retry succeeded", + libLog.String("source", msg.Source), + libLog.Int("retry_count", msg.RetryCount), + ) + + return true +} + +// safeRetryFunc wraps the caller-provided retryFunc with panic recovery. If +// retryFunc panics, the panic is converted to an error so the caller can +// re-enqueue the message instead of losing it silently. +func (c *Consumer) safeRetryFunc(ctx context.Context, msg *FailedMessage) (retryErr error) { + defer func() { + if r := recover(); r != nil { + retryErr = fmt.Errorf("dlq consumer: retryFunc panicked: %v", r) + + c.logger.Log(ctx, libLog.LevelError, "dlq consumer: retryFunc panic recovered", + libLog.String("source", msg.Source), + libLog.String("panic", fmt.Sprintf("%v", r)), + ) + } + }() + + return c.retryFunc(ctx, msg) +} + +// sanitizeMetricSource returns the source label to use for metric recording. 
+// It validates the source against the configured Sources list. If the message +// source is not in the configured list (e.g. a corrupted or injected value), +// "unknown" is returned to prevent high-cardinality metric label pollution. +func (c *Consumer) sanitizeMetricSource(source string) string { + if slices.Contains(c.cfg.Sources, source) { + return source + } + + return "unknown" +} + +// isRedisNilError reports whether err wraps redis.Nil. The Handler uses +// fmt.Errorf("%w") so errors.Is unwraps correctly through the chain. +func isRedisNilError(err error) bool { + return errors.Is(err, redis.Nil) +} diff --git a/commons/dlq/consumer_test.go b/commons/dlq/consumer_test.go new file mode 100644 index 00000000..250825a5 --- /dev/null +++ b/commons/dlq/consumer_test.go @@ -0,0 +1,691 @@ +//go:build unit + +package dlq + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "testing" + "time" + + tmcore "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core" + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockMetrics captures DLQ metric calls for test verification. 
+type mockMetrics struct { + mu sync.Mutex + retriedCalls []string + exhaustedCalls []string + lostCalls []string +} + +func (m *mockMetrics) RecordRetried(_ context.Context, source string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.retriedCalls = append(m.retriedCalls, source) +} + +func (m *mockMetrics) RecordExhausted(_ context.Context, source string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.exhaustedCalls = append(m.exhaustedCalls, source) +} + +func (m *mockMetrics) RecordLost(_ context.Context, source string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.lostCalls = append(m.lostCalls, source) +} + +func (m *mockMetrics) retriedCount() int { + m.mu.Lock() + defer m.mu.Unlock() + + return len(m.retriedCalls) +} + +func (m *mockMetrics) exhaustedCount() int { + m.mu.Lock() + defer m.mu.Unlock() + + return len(m.exhaustedCalls) +} + +func (m *mockMetrics) lostCount() int { + m.mu.Lock() + defer m.mu.Unlock() + + return len(m.lostCalls) +} + +// injectMessage pushes a FailedMessage directly into the Redis list, +// bypassing Handler.Enqueue so that NextRetryAt is not recalculated. +// This allows tests to control the exact timing semantics. +func injectMessage(t *testing.T, mr *miniredis.Miniredis, key string, msg *FailedMessage) { + t.Helper() + + data, err := json.Marshal(msg) + require.NoError(t, err) + + mr.RPush(key, string(data)) +} + +// newTestConsumer creates a Consumer with a fresh Handler, miniredis, and the +// given retryFunc. Configures "outbound" as the single source. 
+func newTestConsumer(t *testing.T, retryFn RetryFunc, metrics *mockMetrics) (*Consumer, *Handler, *miniredis.Miniredis) { + t.Helper() + + mr := miniredis.RunT(t) + conn := newTestRedisClient(t, mr) + + h := New(conn, "dlq:", 3, WithMetrics(metrics)) + + c, err := NewConsumer(h, retryFn, + WithSources("outbound"), + WithBatchSize(10), + WithPollInterval(100*time.Millisecond), + WithConsumerMetrics(metrics), + ) + require.NoError(t, err) + + return c, h, mr +} + +func TestNewConsumer_NilHandler(t *testing.T) { + t.Parallel() + + _, err := NewConsumer(nil, func(_ context.Context, _ *FailedMessage) error { return nil }) + require.ErrorIs(t, err, ErrNilHandler) +} + +func TestNewConsumer_NilRetryFunc(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newTestRedisClient(t, mr) + + h := New(conn, "dlq:", 3) + + _, err := NewConsumer(h, nil) + require.ErrorIs(t, err, ErrNilRetryFunc) +} + +func TestNewConsumer_Defaults(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newTestRedisClient(t, mr) + + h := New(conn, "dlq:", 3) + + c, err := NewConsumer(h, func(_ context.Context, _ *FailedMessage) error { return nil }) + require.NoError(t, err) + + assert.Equal(t, 30*time.Second, c.cfg.PollInterval, "default PollInterval should be 30s") + assert.Equal(t, 10, c.cfg.BatchSize, "default BatchSize should be 10") + assert.NotNil(t, c.logger, "logger should never be nil") + assert.NotNil(t, c.tracer, "tracer should never be nil") +} + +func TestProcessOnce_RetrySuccess(t *testing.T) { + t.Parallel() + + metrics := &mockMetrics{} + + retryFn := func(_ context.Context, _ *FailedMessage) error { + return nil // Retry succeeds. + } + + c, h, mr := newTestConsumer(t, retryFn, metrics) + + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + // Inject directly into Redis with NextRetryAt in the past so the consumer + // considers the message immediately retryable. 
Enqueue recalculates backoff, + // so we bypass it for precise timing control. + injectMessage(t, mr, "dlq:tenant-abc:outbound", &FailedMessage{ + Source: "outbound", + OriginalData: []byte(`{"id":"1"}`), + ErrorMessage: "transient error", + RetryCount: 1, + MaxRetries: 3, + CreatedAt: time.Now().UTC().Add(-5 * time.Minute), + NextRetryAt: time.Now().UTC().Add(-1 * time.Minute), + TenantID: "tenant-abc", + }) + + c.ProcessOnce(ctx) + + assert.Equal(t, 1, metrics.retriedCount(), "successful retry should record a retried metric") + assert.Equal(t, 0, metrics.exhaustedCount()) + + // Message should be consumed (removed from queue). + length, err := h.QueueLength(ctx, "outbound") + require.NoError(t, err) + assert.Equal(t, int64(0), length, "queue should be empty after successful retry") +} + +func TestProcessOnce_RetryFailed(t *testing.T) { + t.Parallel() + + metrics := &mockMetrics{} + + retryFn := func(_ context.Context, _ *FailedMessage) error { + return errors.New("still broken") + } + + c, h, mr := newTestConsumer(t, retryFn, metrics) + + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + // Inject with past NextRetryAt so it is immediately eligible for retry. + originalCreatedAt := time.Now().UTC().Add(-5 * time.Minute) + + injectMessage(t, mr, "dlq:tenant-abc:outbound", &FailedMessage{ + Source: "outbound", + OriginalData: []byte(`{"id":"2"}`), + ErrorMessage: "original error", + RetryCount: 0, + MaxRetries: 3, + CreatedAt: originalCreatedAt, + NextRetryAt: time.Now().UTC().Add(-1 * time.Minute), + TenantID: "tenant-abc", + }) + + c.ProcessOnce(ctx) + + assert.Equal(t, 0, metrics.retriedCount(), "failed retry should not record retried") + assert.Equal(t, 0, metrics.exhaustedCount(), "not yet exhausted") + + // Message should be re-enqueued with incremented RetryCount. 
+	length, err := h.QueueLength(ctx, "outbound")
+	require.NoError(t, err)
+	assert.Equal(t, int64(1), length, "message should be re-enqueued")
+
+	// Drain the re-enqueued message to inspect the consumer's bookkeeping fields.
+	msg, err := h.Dequeue(ctx, "outbound")
+	require.NoError(t, err)
+	assert.Equal(t, 1, msg.RetryCount, "RetryCount should be incremented")
+	assert.Equal(t, "still broken", msg.ErrorMessage, "ErrorMessage should reflect last failure")
+
+	// CRITICAL-1 fix: CreatedAt and MaxRetries must survive the re-enqueue.
+	assert.Equal(t, originalCreatedAt.Unix(), msg.CreatedAt.Unix(),
+		"CreatedAt must be preserved on re-enqueue after retry failure")
+	assert.Equal(t, 3, msg.MaxRetries,
+		"MaxRetries must be preserved on re-enqueue after retry failure")
+	assert.False(t, msg.NextRetryAt.IsZero(),
+		"NextRetryAt should be recalculated by consumer on retry failure")
+}
+
+// TestProcessOnce_Exhausted verifies that a message already at its retry limit
+// is discarded without invoking retryFn and records the exhausted metric.
+func TestProcessOnce_Exhausted(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	// retryFn fails loudly if invoked — exhausted messages must never reach it.
+	retryFn := func(_ context.Context, _ *FailedMessage) error {
+		return errors.New("should not be called for exhausted messages")
+	}
+
+	c, h, _ := newTestConsumer(t, retryFn, metrics)
+
+	ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc")
+
+	// Enqueue a message at max retries — should be discarded without calling retryFn.
+	require.NoError(t, h.Enqueue(ctx, &FailedMessage{
+		Source:       "outbound",
+		OriginalData: []byte(`{"id":"3"}`),
+		ErrorMessage: "permanent failure",
+		RetryCount:   3, // = maxRetries
+	}))
+
+	c.ProcessOnce(ctx)
+
+	assert.Equal(t, 0, metrics.retriedCount())
+	assert.Equal(t, 1, metrics.exhaustedCount(), "exhausted message should record exhausted metric")
+
+	length, err := h.QueueLength(ctx, "outbound")
+	require.NoError(t, err)
+	assert.Equal(t, int64(0), length, "exhausted message should be discarded")
+}
+
+// TestProcessOnce_NotYetReady verifies that a message whose NextRetryAt lies in
+// the future is re-enqueued untouched: retryFn is not called, no metrics fire,
+// and CreatedAt/MaxRetries/RetryCount survive the round trip.
+func TestProcessOnce_NotYetReady(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	called := false
+
+	retryFn := func(_ context.Context, _ *FailedMessage) error {
+		called = true
+		return nil
+	}
+
+	c, h, mr := newTestConsumer(t, retryFn, metrics)
+
+	ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc")
+
+	// Inject directly into Redis with a known CreatedAt so we can verify
+	// it survives the re-enqueue path without being overwritten.
+	originalCreatedAt := time.Now().UTC().Add(-10 * time.Minute)
+	futureRetryAt := time.Now().UTC().Add(1 * time.Hour)
+
+	injectMessage(t, mr, "dlq:tenant-abc:outbound", &FailedMessage{
+		Source:       "outbound",
+		OriginalData: []byte(`{"id":"4"}`),
+		ErrorMessage: "not yet",
+		RetryCount:   1,
+		MaxRetries:   3,
+		CreatedAt:    originalCreatedAt,
+		NextRetryAt:  futureRetryAt,
+		TenantID:     "tenant-abc",
+	})
+
+	c.ProcessOnce(ctx)
+
+	assert.False(t, called, "retryFn should NOT be called for a not-yet-ready message")
+	assert.Equal(t, 0, metrics.retriedCount())
+	assert.Equal(t, 0, metrics.exhaustedCount())
+
+	// Message should still be in the queue (re-enqueued).
+	length, err := h.QueueLength(ctx, "outbound")
+	require.NoError(t, err)
+	assert.Equal(t, int64(1), length, "not-yet-ready message should be re-enqueued")
+
+	// Verify the original CreatedAt and MaxRetries were preserved through
+	// the re-enqueue path (CRITICAL-1 fix: Enqueue must not overwrite
+	// these fields on re-enqueue).
+	requeued, err := h.Dequeue(ctx, "outbound")
+	require.NoError(t, err)
+	assert.Equal(t, originalCreatedAt.Unix(), requeued.CreatedAt.Unix(),
+		"CreatedAt must be preserved on re-enqueue")
+	assert.Equal(t, 3, requeued.MaxRetries,
+		"MaxRetries must be preserved on re-enqueue")
+	assert.Equal(t, 1, requeued.RetryCount,
+		"RetryCount must not change for not-yet-ready re-enqueue")
+}
+
+// TestProcessOnce_EmptyQueue verifies that a poll over an empty queue performs
+// no work: retryFn is never invoked and no metrics are recorded.
+func TestProcessOnce_EmptyQueue(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	called := false
+
+	retryFn := func(_ context.Context, _ *FailedMessage) error {
+		called = true
+		return nil
+	}
+
+	c, _, _ := newTestConsumer(t, retryFn, metrics)
+
+	// Process with no messages at all.
+	ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc")
+	c.ProcessOnce(ctx)
+
+	assert.False(t, called, "retryFn should NOT be called on empty queue")
+	assert.Equal(t, 0, metrics.retriedCount())
+	assert.Equal(t, 0, metrics.exhaustedCount())
+}
+
+// TestStop_IdempotentClose verifies that Stop may be called repeatedly without
+// panicking while Run is blocked, and that Run returns once ctx is cancelled.
+func TestStop_IdempotentClose(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	retryFn := func(_ context.Context, _ *FailedMessage) error { return nil }
+
+	c, _, _ := newTestConsumer(t, retryFn, metrics)
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	// Start consumer in a goroutine so Run blocks.
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		c.Run(ctx)
+	}()
+
+	// Give the loop a moment to start.
+	time.Sleep(50 * time.Millisecond)
+
+	// Calling Stop multiple times must not panic.
+	c.Stop()
+	c.Stop()
+	c.Stop()
+
+	cancel()
+	<-done // Wait for Run to return.
+}
+
+// newTestConsumerWithSources creates a Consumer that listens to multiple sources.
+// This is a variant of newTestConsumer for multi-source tests.
+func newTestConsumerWithSources(t *testing.T, retryFn RetryFunc, metrics *mockMetrics, sources ...string) (*Consumer, *Handler, *miniredis.Miniredis) {
+	t.Helper()
+
+	mr := miniredis.RunT(t)
+	conn := newTestRedisClient(t, mr)
+
+	h := New(conn, "dlq:", 3, WithMetrics(metrics))
+
+	c, err := NewConsumer(h, retryFn,
+		WithSources(sources...),
+		WithBatchSize(10),
+		WithPollInterval(100*time.Millisecond),
+		WithConsumerMetrics(metrics),
+	)
+	require.NoError(t, err)
+
+	return c, h, mr
+}
+
+// TestProcessOnce_MultipleSources verifies that a Consumer configured with
+// multiple sources drains messages from all of them in a single ProcessOnce call.
+func TestProcessOnce_MultipleSources(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	retryFn := func(_ context.Context, _ *FailedMessage) error {
+		return nil // all retries succeed
+	}
+
+	c, h, mr := newTestConsumerWithSources(t, retryFn, metrics, "source-a", "source-b")
+
+	ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-multi")
+
+	// Inject one message per source with past NextRetryAt so both are immediately eligible.
+	for _, src := range []string{"source-a", "source-b"} {
+		injectMessage(t, mr, fmt.Sprintf("dlq:tenant-multi:%s", src), &FailedMessage{
+			Source:      src,
+			TenantID:    "tenant-multi",
+			RetryCount:  0,
+			MaxRetries:  3,
+			CreatedAt:   time.Now().UTC().Add(-5 * time.Minute),
+			NextRetryAt: time.Now().UTC().Add(-1 * time.Minute),
+		})
+	}
+
+	c.ProcessOnce(ctx)
+
+	assert.Equal(t, 2, metrics.retriedCount(), "one successful retry per source = 2 total")
+
+	// Both queues should be empty.
+	for _, src := range []string{"source-a", "source-b"} {
+		length, err := h.QueueLength(ctx, src)
+		require.NoError(t, err)
+		assert.Equal(t, int64(0), length, "queue for %q should be empty after ProcessOnce", src)
+	}
+}
+
+// TestProcessOnce_BatchSizeEnforcement verifies that the Consumer respects its
+// BatchSize limit and does not drain more than BatchSize messages per poll cycle.
+func TestProcessOnce_BatchSizeEnforcement(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	retryFn := func(_ context.Context, _ *FailedMessage) error {
+		return nil
+	}
+
+	mr := miniredis.RunT(t)
+	conn := newTestRedisClient(t, mr)
+
+	h := New(conn, "dlq:", 3, WithMetrics(metrics))
+
+	// BatchSize deliberately set to 2 so only 2 of the 5 injected messages
+	// should be processed in a single poll cycle.
+	c, err := NewConsumer(h, retryFn,
+		WithSources("outbound"),
+		WithBatchSize(2),
+		WithConsumerMetrics(metrics),
+	)
+	require.NoError(t, err)
+
+	ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-batch")
+
+	// Inject 5 immediately-retryable messages.
+	for i := range 5 {
+		injectMessage(t, mr, "dlq:tenant-batch:outbound", &FailedMessage{
+			Source:       "outbound",
+			TenantID:     "tenant-batch",
+			RetryCount:   0,
+			MaxRetries:   3,
+			CreatedAt:    time.Now().UTC().Add(-5 * time.Minute),
+			NextRetryAt:  time.Now().UTC().Add(-1 * time.Minute),
+			OriginalData: []byte(fmt.Sprintf(`{"i":%d}`, i)),
+		})
+	}
+
+	c.ProcessOnce(ctx)
+
+	// Only 2 should have been processed.
+	assert.Equal(t, 2, metrics.retriedCount(), "BatchSize=2 means at most 2 messages processed per cycle")
+
+	// 3 should remain in the queue.
+	remaining, err := h.QueueLength(ctx, "outbound")
+	require.NoError(t, err)
+	assert.Equal(t, int64(3), remaining, "3 of 5 messages should remain after batch-limited ProcessOnce")
+}
+
+// TestConsumerOptions verifies the boundary behaviour of each ConsumerOption:
+// nil inputs are no-ops, zero/negative numeric values keep the default, and
+// valid string/duration values are applied correctly.
+func TestConsumerOptions(t *testing.T) {
+	t.Parallel()
+
+	mr := miniredis.RunT(t)
+	conn := newTestRedisClient(t, mr)
+
+	baseHandler := New(conn, "dlq:", 3)
+	noop := func(_ context.Context, _ *FailedMessage) error { return nil }
+
+	t.Run("WithConsumerLogger nil is no-op", func(t *testing.T) {
+		t.Parallel()
+
+		c, err := NewConsumer(baseHandler, noop, WithConsumerLogger(nil))
+		require.NoError(t, err)
+		assert.NotNil(t, c.logger, "logger should remain non-nil after WithConsumerLogger(nil)")
+	})
+
+	t.Run("WithConsumerTracer nil is no-op", func(t *testing.T) {
+		t.Parallel()
+
+		c, err := NewConsumer(baseHandler, noop, WithConsumerTracer(nil))
+		require.NoError(t, err)
+		assert.NotNil(t, c.tracer, "tracer should remain non-nil after WithConsumerTracer(nil)")
+	})
+
+	t.Run("WithConsumerModule sets module", func(t *testing.T) {
+		t.Parallel()
+
+		c, err := NewConsumer(baseHandler, noop, WithConsumerModule("test-module"))
+		require.NoError(t, err)
+		assert.Equal(t, "test-module", c.module)
+	})
+
+	t.Run("WithPollInterval zero keeps default", func(t *testing.T) {
+		t.Parallel()
+
+		c, err := NewConsumer(baseHandler, noop, WithPollInterval(0))
+		require.NoError(t, err)
+		assert.Equal(t, 30*time.Second, c.cfg.PollInterval,
+			"zero PollInterval should keep the 30s default")
+	})
+
+	t.Run("WithBatchSize negative keeps default", func(t *testing.T) {
+		t.Parallel()
+
+		c, err := NewConsumer(baseHandler, noop, WithBatchSize(-1))
+		require.NoError(t, err)
+		assert.Equal(t, 10, c.cfg.BatchSize,
+			"negative BatchSize should keep the 10 default")
+	})
+}
+
+// TestProcessOnce_Exhausted_CalledFlag is a targeted regression test that
+// verifies retryFn is NOT invoked for an exhausted message (RetryCount >= MaxRetries).
+// It uses an explicit called flag to make this assertion unambiguous.
+func TestProcessOnce_Exhausted_CalledFlag(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	called := false
+
+	retryFn := func(_ context.Context, _ *FailedMessage) error {
+		called = true
+		return nil
+	}
+
+	c, h, mr := newTestConsumer(t, retryFn, metrics)
+
+	ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc")
+
+	// Inject a message already at MaxRetries with a past NextRetryAt so the
+	// consumer considers it immediately eligible for the exhaustion check.
+	injectMessage(t, mr, "dlq:tenant-abc:outbound", &FailedMessage{
+		Source:      "outbound",
+		TenantID:    "tenant-abc",
+		RetryCount:  3,
+		MaxRetries:  3,
+		CreatedAt:   time.Now().UTC().Add(-1 * time.Hour),
+		NextRetryAt: time.Now().UTC().Add(-1 * time.Minute),
+	})
+
+	c.ProcessOnce(ctx)
+
+	assert.False(t, called, "retryFn must NOT be called for an exhausted message")
+	assert.Equal(t, 1, metrics.exhaustedCount(), "exhausted metric should be recorded once")
+	assert.Equal(t, 0, metrics.retriedCount(), "retried metric must not be recorded")
+
+	// Message should be discarded — queue must be empty.
+	length, err := h.QueueLength(ctx, "outbound")
+	require.NoError(t, err)
+	assert.Equal(t, int64(0), length, "exhausted message must be discarded from the queue")
+}
+
+// TestProcessOnce_CancelledContext verifies that ProcessOnce returns quickly
+// and performs no work when its context is already cancelled before the call.
+// This covers the ctx.Err() guard inside drainSource which short-circuits the
+// dequeue loop on cancellation.
+func TestProcessOnce_CancelledContext(t *testing.T) {
+	t.Parallel()
+
+	metrics := &mockMetrics{}
+
+	called := false
+
+	retryFn := func(_ context.Context, _ *FailedMessage) error {
+		called = true
+		return nil
+	}
+
+	// The Handler returned by newTestConsumer is not needed here — discard it
+	// at the call site instead of declaring it and blank-assigning it later.
+	c, _, mr := newTestConsumer(t, retryFn, metrics)
+
+	// Inject a retryable message into Redis so there IS work to do in principle.
+	injectMessage(t, mr, "dlq:tenant-cancel:outbound", &FailedMessage{
+		Source:      "outbound",
+		TenantID:    "tenant-cancel",
+		RetryCount:  0,
+		MaxRetries:  3,
+		CreatedAt:   time.Now().UTC().Add(-5 * time.Minute),
+		NextRetryAt: time.Now().UTC().Add(-1 * time.Minute),
+	})
+
+	// Pre-cancel the context BEFORE calling ProcessOnce.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	start := time.Now()
+	c.ProcessOnce(ctx)
+	elapsed := time.Since(start)
+
+	// ProcessOnce must return quickly — no blocking or long work.
+	assert.Less(t, elapsed, 2*time.Second,
+		"ProcessOnce with a pre-cancelled context should return quickly, took %s", elapsed)
+
+	// No metrics should be recorded and retryFn must not have been invoked.
+	// NOTE: the ctx.Err() guard fires inside drainSource's select, AFTER
+	// ScanQueues has already been called (it runs against Redis directly).
+	// If ScanQueues completes before the context cancellation is checked,
+	// the message may be dequeued but then the keyCtx.Err() guard in
+	// processSource fires before drainSource is entered. Either way,
+	// retryFn must never be called.
+	assert.False(t, called, "retryFn must NOT be called when context is pre-cancelled")
+
+	// Queue integrity: the message may or may not have been dequeued depending
+	// on where exactly the cancellation was observed; we only assert retryFn
+	// was not called and no metrics were emitted.
+	assert.Equal(t, 0, metrics.retriedCount(), "no retried metrics on cancelled context")
+	assert.Equal(t, 0, metrics.exhaustedCount(), "no exhausted metrics on cancelled context")
+}
+
+// TestNewConsumer_NilOptions verifies that passing nil ConsumerOption values to
+// NewConsumer is safe — nil options are silently skipped and the Consumer is
+// constructed with its defaults intact.
+func TestNewConsumer_NilOptions(t *testing.T) {
+	t.Parallel()
+
+	mr := miniredis.RunT(t)
+	conn := newTestRedisClient(t, mr)
+
+	h := New(conn, "dlq:", 3)
+
+	// Pass a nil option alongside a valid one to verify nil-safety.
+	c, err := NewConsumer(h,
+		func(_ context.Context, _ *FailedMessage) error { return nil },
+		nil,              // nil option must be skipped
+		WithBatchSize(5), // valid option applied after nil
+		nil,              // trailing nil also fine
+	)
+	require.NoError(t, err, "nil ConsumerOption must not cause an error")
+	require.NotNil(t, c, "Consumer must be non-nil even when nil options are passed")
+
+	assert.Equal(t, 5, c.cfg.BatchSize, "valid option after nil must still be applied")
+	assert.NotNil(t, c.logger, "logger must remain non-nil")
+	assert.NotNil(t, c.tracer, "tracer must remain non-nil")
+}
+
+// TestIsRedisNilError covers the four cases for isRedisNilError:
+// direct redis.Nil sentinel, wrapped sentinel, arbitrary error, and nil.
+func TestIsRedisNilError(t *testing.T) {
+	t.Parallel()
+
+	t.Run("direct redis.Nil returns true", func(t *testing.T) {
+		t.Parallel()
+
+		assert.True(t, isRedisNilError(redis.Nil))
+	})
+
+	t.Run("wrapped redis.Nil returns true", func(t *testing.T) {
+		t.Parallel()
+
+		// Mirrors how Dequeue wraps the sentinel with %w.
+		wrapped := fmt.Errorf("dlq: dequeue: %w", redis.Nil)
+		assert.True(t, isRedisNilError(wrapped))
+	})
+
+	t.Run("non-redis error returns false", func(t *testing.T) {
+		t.Parallel()
+
+		assert.False(t, isRedisNilError(errors.New("some other error")))
+	})
+
+	t.Run("nil error returns false", func(t *testing.T) {
+		t.Parallel()
+
+		assert.False(t, isRedisNilError(nil))
+	})
+}
diff --git a/commons/dlq/doc.go b/commons/dlq/doc.go
new file mode 100644
index 00000000..ba713133
--- /dev/null
+++ b/commons/dlq/doc.go
@@ -0,0 +1,7 @@
+// Package dlq provides a Redis-backed dead letter queue with exponential backoff retry.
+//
+// Messages that fail processing are enqueued into tenant-isolated Redis lists
+// and retried with configurable exponential backoff. A background Consumer polls
+// these lists, invokes a caller-provided RetryFunc, and discards messages that
+// either succeed or exhaust their retry budget.
+package dlq
diff --git a/commons/dlq/errors.go b/commons/dlq/errors.go
new file mode 100644
index 00000000..be428af9
--- /dev/null
+++ b/commons/dlq/errors.go
@@ -0,0 +1,12 @@
+package dlq
+
+import "errors"
+
+var (
+	// ErrNilHandler is returned when a Handler method is called on a nil receiver.
+	ErrNilHandler = errors.New("dlq handler is nil")
+	// ErrNilRetryFunc is returned when a Consumer is created without a retry function.
+	ErrNilRetryFunc = errors.New("dlq retry function is nil")
+	// ErrMessageExhausted indicates a message exceeded its maximum retry count.
+	ErrMessageExhausted = errors.New("dlq message exhausted all retries")
+)
diff --git a/commons/dlq/handler.go b/commons/dlq/handler.go
new file mode 100644
index 00000000..63a92232
--- /dev/null
+++ b/commons/dlq/handler.go
@@ -0,0 +1,518 @@
+package dlq
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"time"
+	"unicode/utf8"
+
+	libBackoff "github.com/LerianStudio/lib-commons/v4/commons/backoff"
+	libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+	libOtel "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+	libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+	tmcore "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+	"go.opentelemetry.io/otel/trace"
+	"go.opentelemetry.io/otel/trace/noop"
+)
+
+// DLQMetrics records DLQ-specific counters. Implementations are optional;
+// when nil, metric calls are silently skipped.
+type DLQMetrics interface {
+	RecordRetried(ctx context.Context, source string)
+	RecordExhausted(ctx context.Context, source string)
+	RecordLost(ctx context.Context, source string)
+}
+
+// FailedMessage represents a message that failed processing and was routed to
+// the dead letter queue for later retry.
+type FailedMessage struct {
+	// Source names the originating queue/component; it becomes part of the Redis key.
+	Source string `json:"source"`
+	// OriginalData is the opaque payload of the failed message.
+	OriginalData []byte `json:"original_data"`
+	// ErrorMessage describes the last failure observed for this message.
+	ErrorMessage string `json:"error_message"`
+	// RetryCount is the number of retry attempts performed so far.
+	RetryCount int `json:"retry_count"`
+	// MaxRetries is the maximum number of retry attempts. A value of 0 is treated
+	// as "use handler default" and will be overwritten during Enqueue. To allow
+	// zero retries (immediate discard on first failure), set MaxRetries to the
+	// handler's configured value and RetryCount to that same value.
+	MaxRetries  int       `json:"max_retries"`
+	CreatedAt   time.Time `json:"created_at"`
+	NextRetryAt time.Time `json:"next_retry_at,omitzero"`
+	TenantID    string    `json:"tenant_id,omitempty"`
+}
+
+// Handler manages dead letter queue operations backed by Redis lists.
+type Handler struct {
+	conn       *libRedis.Client
+	keyPrefix  string
+	maxRetries int
+	logger     libLog.Logger
+	tracer     trace.Tracer
+	metrics    DLQMetrics
+	module     string
+}
+
+// Option configures a Handler at construction time.
+type Option func(*Handler)
+
+// WithLogger sets the logger used by the Handler.
+func WithLogger(l libLog.Logger) Option {
+	return func(h *Handler) {
+		if l != nil {
+			h.logger = l
+		}
+	}
+}
+
+// WithTracer sets the OpenTelemetry tracer used by the Handler.
+func WithTracer(t trace.Tracer) Option {
+	return func(h *Handler) {
+		if t != nil {
+			h.tracer = t
+		}
+	}
+}
+
+// WithMetrics sets the metrics recorder used by the Handler.
+func WithMetrics(m DLQMetrics) Option {
+	return func(h *Handler) {
+		h.metrics = m
+	}
+}
+
+// WithModule sets a module label used in log and metric context.
+func WithModule(module string) Option {
+	return func(h *Handler) {
+		if module != "" {
+			h.module = module
+		}
+	}
+}
+
+// New creates a Handler backed by the given Redis client. keyPrefix is prepended
+// to all Redis keys (e.g. "dlq:"). maxRetries controls how many times a message
+// may be retried before it is considered exhausted.
+// Returns nil when conn is nil — all exported Handler methods already guard
+// against a nil receiver and return ErrNilHandler, so callers are safe.
+func New(conn *libRedis.Client, keyPrefix string, maxRetries int, opts ...Option) *Handler {
+	if conn == nil {
+		return nil
+	}
+
+	// Non-positive maxRetries and empty keyPrefix fall back to safe defaults.
+	if maxRetries <= 0 {
+		maxRetries = 3
+	}
+
+	if keyPrefix == "" {
+		keyPrefix = "dlq:"
+	}
+
+	h := &Handler{
+		conn:       conn,
+		keyPrefix:  keyPrefix,
+		maxRetries: maxRetries,
+		logger:     libLog.NewNop(),
+		tracer:     noop.NewTracerProvider().Tracer("dlq.noop"),
+	}
+
+	// nil options are skipped so variadic callers can pass conditionals safely.
+	for _, opt := range opts {
+		if opt != nil {
+			opt(h)
+		}
+	}
+
+	return h
+}
+
+// backoffDuration calculates exponential backoff with jitter for retry timing.
+// Base delay of 30s matches br-spb's original formula. Uses AWS Full Jitter
+// strategy via lib-commons/backoff for better cluster behavior.
+// The floor is 5s (not 30s) so that attempt 0 gets genuine jitter spread
+// over [5s, 30s) rather than always resolving to exactly 30s.
+func backoffDuration(retryCount int) time.Duration {
+	const minBackoff = 5 * time.Second
+
+	d := libBackoff.ExponentialWithJitter(30*time.Second, retryCount)
+
+	return max(d, minBackoff)
+}
+
+// Enqueue adds a failed message to the DLQ. The message's TenantID is
+// resolved from the context when not already set on the message itself.
+// If msg.MaxRetries is 0 on the initial enqueue (CreatedAt is zero), it is
+// overwritten with the handler's configured maxRetries value. See the
+// MaxRetries field doc for how to express a "zero retries allowed" policy.
+func (h *Handler) Enqueue(ctx context.Context, msg *FailedMessage) error {
+	if h == nil {
+		return ErrNilHandler
+	}
+
+	if msg == nil {
+		return errors.New("dlq: enqueue: nil message")
+	}
+
+	if msg.Source == "" {
+		return errors.New("dlq: enqueue: source must not be empty")
+	}
+
+	if err := validateKeySegment("source", msg.Source); err != nil {
+		return err
+	}
+
+	ctx, span := h.tracer.Start(ctx, "dlq.enqueue")
+	defer span.End()
+
+	// Resolve and validate the tenant BEFORE mutating the message. Previously
+	// CreatedAt/MaxRetries were stamped first, so a tenant-mismatch rejection
+	// returned an error while leaving the caller's message partially mutated —
+	// a later legitimate enqueue would then be misclassified as a re-enqueue
+	// because CreatedAt was no longer zero.
+	ctxTenant := tmcore.GetTenantIDContext(ctx)
+
+	effectiveTenant := msg.TenantID
+	if effectiveTenant != "" && ctxTenant != "" && effectiveTenant != ctxTenant {
+		return fmt.Errorf("dlq: enqueue: tenant mismatch between message (%s) and context (%s)", effectiveTenant, ctxTenant)
+	}
+
+	if effectiveTenant == "" {
+		effectiveTenant = ctxTenant
+		msg.TenantID = effectiveTenant
+	}
+
+	// Only stamp CreatedAt and MaxRetries on initial enqueue (zero-valued).
+	// Re-enqueue paths (consumer retry-failed, not-yet-ready, prune) pass
+	// messages that already carry the original values; overwriting them would
+	// permanently lose the original failure timestamp and retry budget.
+	initialEnqueue := msg.CreatedAt.IsZero()
+	if initialEnqueue {
+		msg.CreatedAt = time.Now().UTC()
+	}
+
+	if msg.MaxRetries == 0 {
+		msg.MaxRetries = h.maxRetries
+	}
+
+	// Recalculate NextRetryAt only on initial enqueue. On re-enqueue the
+	// consumer has already incremented RetryCount and the caller is
+	// responsible for timing; we preserve their NextRetryAt or let the
+	// backoff be recalculated by the consumer path that sets RetryCount.
+	if initialEnqueue && msg.RetryCount < msg.MaxRetries {
+		msg.NextRetryAt = msg.CreatedAt.Add(backoffDuration(msg.RetryCount))
+	}
+
+	data, err := json.Marshal(msg)
+	if err != nil {
+		libOtel.HandleSpanError(span, "dlq marshal failed", err)
+
+		return fmt.Errorf("dlq: enqueue: marshal: %w", err)
+	}
+
+	key := h.tenantScopedKeyForTenant(effectiveTenant, msg.Source)
+
+	rds, err := h.conn.GetClient(ctx)
+	if err != nil {
+		libOtel.HandleSpanError(span, "dlq redis client unavailable", err)
+		h.logEnqueueFallback(ctx, key, msg, err)
+
+		return fmt.Errorf("dlq: enqueue: redis client: %w", err)
+	}
+
+	// RPush preserves FIFO order relative to LPop in Dequeue.
+	if pushErr := rds.RPush(ctx, key, data).Err(); pushErr != nil {
+		libOtel.HandleSpanError(span, "dlq rpush failed", pushErr)
+		h.logEnqueueFallback(ctx, key, msg, pushErr)
+
+		return fmt.Errorf("dlq: enqueue: rpush: %w", pushErr)
+	}
+
+	return nil
+}
+
+// logEnqueueFallback logs message metadata when Redis is unreachable. The
+// payload is redacted to prevent PII leakage into log aggregators.
+func (h *Handler) logEnqueueFallback(ctx context.Context, key string, msg *FailedMessage, err error) {
+	h.logger.Log(ctx, libLog.LevelError,
+		"dlq: failed to enqueue message to Redis — payload redacted for PII safety",
+		libLog.String("dlq_key", key),
+		libLog.String("msg_source", msg.Source),
+		libLog.Int("retry_count", msg.RetryCount),
+		libLog.String("original_error", truncateString(msg.ErrorMessage, 200)),
+		libLog.Err(err),
+	)
+}
+
+// Dequeue atomically removes and returns the next message from the given source queue.
+// NOTE: This uses LPop which is destructive. If the process crashes between Dequeue
+// and a subsequent re-enqueue, the message is permanently lost. This provides
+// at-most-once delivery semantics. For at-least-once, consider using LMOVE (Redis 6.2+).
+func (h *Handler) Dequeue(ctx context.Context, source string) (*FailedMessage, error) {
+	if h == nil {
+		return nil, ErrNilHandler
+	}
+
+	if err := validateKeySegment("source", source); err != nil {
+		return nil, err
+	}
+
+	ctx, span := h.tracer.Start(ctx, "dlq.dequeue")
+	defer span.End()
+
+	key := h.tenantScopedKey(ctx, source)
+
+	rds, err := h.conn.GetClient(ctx)
+	if err != nil {
+		libOtel.HandleSpanError(span, "dlq redis client unavailable", err)
+
+		return nil, fmt.Errorf("dlq: dequeue: redis client: %w", err)
+	}
+
+	// On an empty queue LPop returns the redis.Nil sentinel; the %w wrap keeps
+	// it detectable via isRedisNilError upstream.
+	data, err := rds.LPop(ctx, key).Result()
+	if err != nil {
+		return nil, fmt.Errorf("dlq: dequeue: %w", err)
+	}
+
+	var msg FailedMessage
+	if err := json.Unmarshal([]byte(data), &msg); err != nil {
+		libOtel.HandleSpanError(span, "dlq unmarshal failed", err)
+
+		return nil, fmt.Errorf("dlq: dequeue: unmarshal: %w", err)
+	}
+
+	return &msg, nil
+}
+
+// QueueLength returns the number of messages in the DLQ for the given source.
+func (h *Handler) QueueLength(ctx context.Context, source string) (int64, error) {
+	if h == nil {
+		return 0, ErrNilHandler
+	}
+
+	if err := validateKeySegment("source", source); err != nil {
+		return 0, err
+	}
+
+	key := h.tenantScopedKey(ctx, source)
+
+	rds, err := h.conn.GetClient(ctx)
+	if err != nil {
+		return 0, fmt.Errorf("dlq: queue length: redis client: %w", err)
+	}
+
+	return rds.LLen(ctx, key).Result()
+}
+
+// ScanQueues discovers all tenant-scoped Redis keys matching the pattern
+// "{keyPrefix}*:{source}". This enables a background consumer (running without
+// tenant context) to find keys like "dlq:tenant-A:outbound".
+//
+// The SCAN command is used instead of KEYS to avoid blocking Redis on large
+// keyspaces. Returns full Redis keys; the caller can use ExtractTenantFromKey
+// to recover the tenant ID.
+func (h *Handler) ScanQueues(ctx context.Context, source string) ([]string, error) {
+	if h == nil {
+		return nil, ErrNilHandler
+	}
+
+	if err := validateKeySegment("source", source); err != nil {
+		return nil, err
+	}
+
+	ctx, span := h.tracer.Start(ctx, "dlq.scan_queues")
+	defer span.End()
+
+	pattern := fmt.Sprintf("%s*:%s", h.keyPrefix, source)
+	// globalKey is the tenant-less key; it is excluded from the results below.
+	globalKey := fmt.Sprintf("%s%s", h.keyPrefix, source)
+
+	rds, err := h.conn.GetClient(ctx)
+	if err != nil {
+		libOtel.HandleSpanError(span, "dlq redis client unavailable", err)
+
+		return nil, fmt.Errorf("dlq: scan queues: redis client: %w", err)
+	}
+
+	var keys []string
+
+	var cursor uint64
+
+	// Iterate the SCAN cursor until Redis reports completion (cursor == 0).
+	for {
+		var batch []string
+
+		var scanErr error
+
+		batch, cursor, scanErr = rds.Scan(ctx, cursor, pattern, 100).Result()
+		if scanErr != nil {
+			libOtel.HandleSpanError(span, "dlq scan failed", scanErr)
+
+			return nil, fmt.Errorf("dlq: scan queues: %w", scanErr)
+		}
+
+		for _, key := range batch {
+			if key != globalKey {
+				keys = append(keys, key)
+			}
+		}
+
+		if cursor == 0 {
+			break
+		}
+	}
+
+	return keys, nil
+}
+
+// PruneExhaustedMessages removes up to limit messages from the DLQ source that
+// have exceeded their maximum retry count. Returns the number of messages pruned.
+//
+// NOTE: This uses LPop (via Dequeue) which is destructive. If the process crashes
+// between Dequeue and a subsequent re-enqueue of non-exhausted messages, those
+// messages are permanently lost. This provides at-most-once delivery semantics.
+// For at-least-once, consider using LMOVE (Redis 6.2+).
+//
+// Note: surviving messages are re-enqueued at the back of the queue. FIFO
+// ordering relative to other messages in the same source is not preserved.
+// This is acceptable for a dead letter queue — messages routed here are already
+// out of their original processing order by definition — but callers that
+// depend on strict ordering should be aware of this behavior.
+func (h *Handler) PruneExhaustedMessages(ctx context.Context, source string, limit int) (int, error) {
+	if h == nil {
+		return 0, ErrNilHandler
+	}
+
+	if err := validateKeySegment("source", source); err != nil {
+		return 0, err
+	}
+
+	// A non-positive limit is a no-op by design, not an error.
+	if limit <= 0 {
+		return 0, nil
+	}
+
+	ctx, span := h.tracer.Start(ctx, "dlq.prune_exhausted")
+	defer span.End()
+
+	pruned := 0
+
+	for range limit {
+		msg, err := h.Dequeue(ctx, source)
+		if err != nil {
+			if isRedisNilError(err) {
+				// Empty queue — done.
+				break
+			}
+
+			// Real error (Redis failure, JSON corruption) — propagate.
+			return pruned, fmt.Errorf("dlq: prune: dequeue: %w", err)
+		}
+
+		if msg.RetryCount >= msg.MaxRetries {
+			pruned++
+
+			h.logger.Log(ctx, libLog.LevelWarn, "dlq: pruned exhausted message",
+				libLog.String("source", msg.Source),
+				libLog.Int("retry_count", msg.RetryCount),
+				libLog.String("tenant_id", msg.TenantID),
+			)
+
+			continue
+		}
+
+		// Not exhausted — put it back.
+		if err := h.Enqueue(ctx, msg); err != nil {
+			h.logger.Log(ctx, libLog.LevelError, "dlq: failed to re-enqueue non-exhausted message during prune",
+				libLog.String("source", msg.Source),
+				libLog.Err(err),
+			)
+
+			return pruned, fmt.Errorf("dlq: prune: re-enqueue: %w", err)
+		}
+	}
+
+	return pruned, nil
+}
+
+// ExtractTenantFromKey extracts the tenant ID from a tenant-scoped Redis key.
+// Given key="dlq:tenant-abc:outbound" and keyPrefix="dlq:", returns "tenant-abc".
+// Returns empty string if the key does not match the expected format.
+func (h *Handler) ExtractTenantFromKey(key, source string) string {
+	if h == nil {
+		return ""
+	}
+
+	prefix := h.keyPrefix
+	suffix := ":" + source
+
+	// The strict '<=' also rejects a zero-length tenant segment.
+	if len(key) <= len(prefix)+len(suffix) {
+		return ""
+	}
+
+	if key[:len(prefix)] != prefix {
+		return ""
+	}
+
+	if key[len(key)-len(suffix):] != suffix {
+		return ""
+	}
+
+	return key[len(prefix) : len(key)-len(suffix)]
+}
+
+// tenantScopedKey constructs a Redis key including the tenant ID from context.
+// With tenant:    "dlq:tenant-abc:outbound"
+// Without tenant: "dlq:outbound"
+func (h *Handler) tenantScopedKey(ctx context.Context, source string) string {
+	return h.tenantScopedKeyForTenant(tmcore.GetTenantIDContext(ctx), source)
+}
+
+// tenantScopedKeyForTenant builds the Redis key for an explicit tenant ID,
+// falling back to the global (tenant-less) key when the tenant segment is
+// empty or fails validation.
+func (h *Handler) tenantScopedKeyForTenant(tenantID, source string) string {
+	if tenantID != "" {
+		if err := validateKeySegment("tenantID", tenantID); err != nil {
+			// Log the invalid segment and fall back to the non-tenant key
+			// rather than constructing a corrupted Redis key.
+			h.logger.Log(context.Background(), libLog.LevelWarn, "dlq: tenantScopedKeyForTenant: invalid tenantID, using global key",
+				libLog.String("tenant_id", tenantID),
+				libLog.Err(err),
+			)
+
+			return fmt.Sprintf("%s%s", h.keyPrefix, source)
+		}
+
+		return fmt.Sprintf("%s%s:%s", h.keyPrefix, tenantID, source)
+	}
+
+	return fmt.Sprintf("%s%s", h.keyPrefix, source)
+}
+
+// validateKeySegment ensures that a Redis key segment (source or tenantID)
+// does not contain characters that would corrupt key patterns or enable
+// injection into SCAN glob patterns. Backslash is included because it is
+// the escape character in Redis SCAN patterns.
+func validateKeySegment(name, value string) error {
+	for _, c := range value {
+		if c == ':' || c == '*' || c == '?' || c == '[' || c == ']' || c == '\\' {
+			return fmt.Errorf("dlq: %s %q contains disallowed character %q", name, value, c)
+		}
+	}
+
+	return nil
+}
+
+// truncateString returns s unchanged when its rune count is within maxLen,
+// otherwise it truncates at the maxLen-th rune boundary and appends "..."
+// to signal that content was trimmed. Rune-aware truncation prevents
+// splitting multi-byte UTF-8 sequences, which would produce invalid output.
+// Used by logEnqueueFallback to prevent large PII-containing error messages
+// from leaking into log aggregators.
+func truncateString(s string, maxLen int) string {
+	if utf8.RuneCountInString(s) <= maxLen {
+		return s
+	}
+
+	// Walk runes and cut at the correct byte offset.
+	byteLen := 0
+	count := 0
+
+	for _, r := range s {
+		if count == maxLen {
+			break
+		}
+
+		byteLen += utf8.RuneLen(r)
+		count++
+	}
+
+	return s[:byteLen] + "..."
+}
diff --git a/commons/dlq/handler_test.go b/commons/dlq/handler_test.go
new file mode 100644
index 00000000..7519b3eb
--- /dev/null
+++ b/commons/dlq/handler_test.go
@@ -0,0 +1,669 @@
+//go:build unit
+
+package dlq
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+	libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+	tmcore "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+	"github.com/alicebob/miniredis/v2"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.opentelemetry.io/otel/trace/noop"
+)
+
+// newTestRedisClient creates a *libRedis.Client backed by miniredis.
+// Follows the exact pattern from commons/redis/lock_test.go and
+// commons/net/http/ratelimit/redis_storage_test.go.
+func newTestRedisClient(t *testing.T, mr *miniredis.Miniredis) *libRedis.Client {
+	t.Helper()
+
+	conn, err := libRedis.New(context.Background(), libRedis.Config{
+		Topology: libRedis.Topology{
+			Standalone: &libRedis.StandaloneTopology{Address: mr.Addr()},
+		},
+		Logger: &libLog.NopLogger{},
+	})
+	require.NoError(t, err)
+
+	// Close the client when the test finishes; miniredis itself is cleaned up
+	// by RunT.
+	t.Cleanup(func() {
+		if err := conn.Close(); err != nil {
+			t.Logf("newTestRedisClient cleanup: conn.Close() error: %v", err)
+		}
+	})
+
+	return conn
+}
+
+// newTestHandler creates a Handler wired to a fresh miniredis instance.
+func newTestHandler(t *testing.T) (*Handler, *miniredis.Miniredis) {
+	t.Helper()
+
+	mr := miniredis.RunT(t)
+	conn := newTestRedisClient(t, mr)
+
+	h := New(conn, "dlq:", 3)
+
+	return h, mr
+}
+
+// TestNew_Defaults verifies the fallback values applied when keyPrefix and
+// maxRetries are zero-valued.
+func TestNew_Defaults(t *testing.T) {
+	t.Parallel()
+
+	mr := miniredis.RunT(t)
+	conn := newTestRedisClient(t, mr)
+
+	h := New(conn, "", 0)
+
+	assert.Equal(t, 3, h.maxRetries, "default maxRetries should be 3")
+	assert.Equal(t, "dlq:", h.keyPrefix, "default keyPrefix should be 'dlq:'")
+	assert.NotNil(t, h.logger, "logger should never be nil")
+	assert.NotNil(t, h.tracer, "tracer should never be nil")
+	assert.Nil(t, h.metrics, "metrics should be nil when not provided")
+}
+
+// TestNew_NilConn verifies that New returns nil without a Redis client and that
+// the resulting nil *Handler is still safe to call.
+func TestNew_NilConn(t *testing.T) {
+	t.Parallel()
+
+	h := New(nil, "dlq:", 3)
+	assert.Nil(t, h, "New should return nil when conn is nil")
+
+	// Nil *Handler is safe to use — all methods return ErrNilHandler.
+	err := h.Enqueue(context.Background(), &FailedMessage{Source: "test"})
+	require.ErrorIs(t, err, ErrNilHandler)
+
+	_, err = h.Dequeue(context.Background(), "test")
+	require.ErrorIs(t, err, ErrNilHandler)
+
+	_, err = h.QueueLength(context.Background(), "test")
+	require.ErrorIs(t, err, ErrNilHandler)
+}
+
+// TestNew_WithOptions verifies that every construction Option is applied.
+func TestNew_WithOptions(t *testing.T) {
+	t.Parallel()
+
+	mr := miniredis.RunT(t)
+	conn := newTestRedisClient(t, mr)
+
+	logger := libLog.NewNop()
+	tracer := noop.NewTracerProvider().Tracer("test")
+	m := &mockMetrics{}
+
+	h := New(conn, "custom:", 5,
+		WithLogger(logger),
+		WithTracer(tracer),
+		WithMetrics(m),
+		WithModule("payments"),
+	)
+
+	assert.Equal(t, 5, h.maxRetries)
+	assert.Equal(t, "custom:", h.keyPrefix)
+	assert.Equal(t, "payments", h.module)
+	assert.Same(t, m, h.metrics, "metrics should be the supplied mock")
+}
+
+func TestEnqueue_NilHandler(t *testing.T) {
+	t.Parallel()
+
+	var h *Handler
+
+	err := h.Enqueue(context.Background(), &FailedMessage{Source: "test"})
+	require.ErrorIs(t, err, ErrNilHandler)
+}
+
+func TestEnqueue_NilMessage(t *testing.T) {
+	t.Parallel()
+
+	h, _ := newTestHandler(t)
+
+	err := h.Enqueue(context.Background(), nil)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "nil message")
+}
+
+// TestEnqueue_Success verifies the happy path: the message lands under the
+// tenant-scoped key and Enqueue stamps MaxRetries, TenantID, CreatedAt, and
+// NextRetryAt on the caller's message.
+func TestEnqueue_Success(t *testing.T) {
+	t.Parallel()
+
+	h, mr := newTestHandler(t)
+	ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc")
+
+	msg := &FailedMessage{
+		Source:       "outbound",
+		OriginalData: []byte(`{"id":1}`),
+		ErrorMessage: "timeout",
+	}
+
+	err := h.Enqueue(ctx, msg)
+	require.NoError(t, err)
+
+	// Verify the message landed in the correct Redis key.
+	key := "dlq:tenant-abc:outbound"
+	items, err := mr.List(key)
+	require.NoError(t, err)
+	assert.Len(t, items, 1, "queue should contain exactly one message")
+
+	// Verify message fields were set by Enqueue.
+	assert.Equal(t, h.maxRetries, msg.MaxRetries, "MaxRetries should be stamped from handler")
+	assert.Equal(t, "tenant-abc", msg.TenantID, "TenantID should be resolved from context")
+	assert.False(t, msg.CreatedAt.IsZero(), "CreatedAt should be set")
+	assert.False(t, msg.NextRetryAt.IsZero(), "NextRetryAt should be set for retryable messages")
+}
+
+func TestEnqueue_TenantFromContext(t *testing.T) {
+	t.Parallel()
+
+	h, mr := newTestHandler(t)
+	ctx := tmcore.ContextWithTenantID(context.Background(), "ctx-tenant")
+
+	msg := &FailedMessage{
+		Source:       "inbound",
+		OriginalData: []byte(`{}`),
+		ErrorMessage: "fail",
+		// TenantID intentionally left empty — should be resolved from ctx.
+ } + + err := h.Enqueue(ctx, msg) + require.NoError(t, err) + assert.Equal(t, "ctx-tenant", msg.TenantID) + + items, err := mr.List("dlq:ctx-tenant:inbound") + require.NoError(t, err) + assert.Len(t, items, 1) +} + +func TestEnqueue_TenantMismatch(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-A") + + msg := &FailedMessage{ + Source: "outbound", + TenantID: "tenant-B", + } + + err := h.Enqueue(ctx, msg) + require.Error(t, err) + assert.Contains(t, err.Error(), "tenant mismatch") +} + +func TestDequeue_Success(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + original := &FailedMessage{ + Source: "outbound", + OriginalData: []byte(`{"key":"value"}`), + ErrorMessage: "connection refused", + } + + require.NoError(t, h.Enqueue(ctx, original)) + + got, err := h.Dequeue(ctx, "outbound") + require.NoError(t, err) + assert.Equal(t, "outbound", got.Source) + assert.Equal(t, "tenant-abc", got.TenantID) + assert.Equal(t, []byte(`{"key":"value"}`), got.OriginalData) +} + +func TestDequeue_EmptyQueue(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + _, err := h.Dequeue(ctx, "nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "redis: nil") +} + +func TestDequeue_NilHandler(t *testing.T) { + t.Parallel() + + var h *Handler + + _, err := h.Dequeue(context.Background(), "test") + require.ErrorIs(t, err, ErrNilHandler) +} + +func TestQueueLength_Empty(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + length, err := h.QueueLength(ctx, "outbound") + require.NoError(t, err) + assert.Equal(t, int64(0), length) +} + +func TestQueueLength_NonEmpty(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := 
tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + for i := range 5 { + require.NoError(t, h.Enqueue(ctx, &FailedMessage{ + Source: "outbound", + ErrorMessage: "err", + OriginalData: []byte{byte(i)}, + })) + } + + length, err := h.QueueLength(ctx, "outbound") + require.NoError(t, err) + assert.Equal(t, int64(5), length) +} + +// TestEnqueue_EmptySource_Rejected verifies that Enqueue rejects messages with +// an empty source. An empty source would produce a malformed Redis key and +// is not a valid routing target. +func TestEnqueue_EmptySource_Rejected(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + msg := &FailedMessage{ + Source: "", + OriginalData: []byte(`{}`), + ErrorMessage: "err", + } + + err := h.Enqueue(ctx, msg) + require.Error(t, err, "empty source should be rejected") + assert.Contains(t, err.Error(), "source must not be empty") +} + +func TestExtractTenantFromKey(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + + tests := []struct { + name string + key string + source string + want string + }{ + { + name: "standard tenant key", + key: "dlq:tenant-abc:outbound", + source: "outbound", + want: "tenant-abc", + }, + { + name: "uuid tenant", + key: "dlq:550e8400-e29b-41d4-a716-446655440000:inbound", + source: "inbound", + want: "550e8400-e29b-41d4-a716-446655440000", + }, + { + name: "no tenant segment (global key)", + key: "dlq:outbound", + source: "outbound", + want: "", + }, + { + name: "wrong prefix", + key: "other:tenant-abc:outbound", + source: "outbound", + want: "", + }, + { + name: "wrong suffix", + key: "dlq:tenant-abc:wrong", + source: "outbound", + want: "", + }, + { + name: "empty key", + key: "", + source: "outbound", + want: "", + }, + { + name: "nil handler returns empty", + key: "dlq:tenant:outbound", + source: "outbound", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
t.Parallel() + + target := h + if tt.name == "nil handler returns empty" { + target = nil + } + + got := target.ExtractTenantFromKey(tt.key, tt.source) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestScanQueues_Success(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + + tenants := []string{"tenant-A", "tenant-B", "tenant-C"} + + for _, tid := range tenants { + ctx := tmcore.ContextWithTenantID(context.Background(), tid) + require.NoError(t, h.Enqueue(ctx, &FailedMessage{ + Source: "outbound", + OriginalData: []byte(`{}`), + ErrorMessage: "err", + })) + } + + keys, err := h.ScanQueues(context.Background(), "outbound") + require.NoError(t, err) + assert.Len(t, keys, 3, "should discover all three tenant-scoped keys") + + // Verify each discovered key maps back to a known tenant. + discovered := make(map[string]bool) + for _, key := range keys { + tid := h.ExtractTenantFromKey(key, "outbound") + discovered[tid] = true + } + + for _, tid := range tenants { + assert.True(t, discovered[tid], "tenant %s should be discovered by ScanQueues", tid) + } +} + +func TestPruneExhaustedMessages(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + // Enqueue two exhausted messages (RetryCount >= MaxRetries). + for range 2 { + require.NoError(t, h.Enqueue(ctx, &FailedMessage{ + Source: "outbound", + OriginalData: []byte(`{}`), + ErrorMessage: "permanent failure", + RetryCount: 3, // = maxRetries + })) + } + + // Enqueue one non-exhausted message. + require.NoError(t, h.Enqueue(ctx, &FailedMessage{ + Source: "outbound", + OriginalData: []byte(`{}`), + ErrorMessage: "transient failure", + RetryCount: 0, + })) + + pruned, err := h.PruneExhaustedMessages(ctx, "outbound", 10) + require.NoError(t, err) + assert.Equal(t, 2, pruned, "exactly two exhausted messages should be pruned") + + // The non-exhausted message should still be in the queue. 
+ length, err := h.QueueLength(ctx, "outbound") + require.NoError(t, err) + assert.Equal(t, int64(1), length, "one non-exhausted message should remain") + + remaining, err := h.Dequeue(ctx, "outbound") + require.NoError(t, err) + assert.Equal(t, 0, remaining.RetryCount, "surviving message should be the non-exhausted one") +} + +// TestDequeue_MalformedJSON verifies that Dequeue returns an error (not a nil +// pointer dereference) when the Redis list contains non-JSON data. This covers +// the case where a message was manually injected or corrupted in transit. +func TestDequeue_MalformedJSON(t *testing.T) { + t.Parallel() + + h, mr := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + // Inject raw non-JSON bytes directly via miniredis, bypassing Enqueue. + mr.RPush("dlq:tenant-abc:outbound", "this is not valid json {{{") + + got, err := h.Dequeue(ctx, "outbound") + require.Error(t, err, "malformed JSON payload should return an error") + assert.Nil(t, got, "result pointer should be nil on unmarshal failure") + assert.Contains(t, err.Error(), "unmarshal", "error should indicate unmarshal failure") +} + +// TestBackoffDuration verifies the backoffDuration helper: +// - floor is at least 5 seconds for any retry count +// - values grow in trend as retryCount increases +// - jitter causes variance (calling twice rarely returns identical values) +func TestBackoffDuration(t *testing.T) { + t.Parallel() + + minFloor := 5 * time.Second + + for _, count := range []int{0, 1, 5} { + d := backoffDuration(count) + assert.GreaterOrEqual(t, d, minFloor, + "backoffDuration(%d) should be >= 5s floor, got %s", count, d) + } + + // Higher retry counts should not produce values below floor. 
+ d5 := backoffDuration(5) + assert.GreaterOrEqual(t, d5, minFloor, "retryCount=5 must still respect floor") + + // Jitter: calling the function multiple times for the same retryCount + // should eventually produce different values (probabilistic — we allow + // up to 20 attempts before concluding there is no variance). + const attempts = 20 + + seenVariance := false + + first := backoffDuration(2) + + for range attempts { + next := backoffDuration(2) + if next != first { + seenVariance = true + + break + } + } + + assert.True(t, seenVariance, "backoffDuration should exhibit jitter variance across multiple calls") +} + +// TestEnqueue_InvalidSourceChars verifies that Enqueue rejects source strings +// that contain characters which would corrupt Redis key patterns or enable +// glob-injection into SCAN commands. +func TestEnqueue_InvalidSourceChars(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := context.Background() + + invalidChars := []struct { + name string + source string + }{ + {"colon", "out:bound"}, + {"asterisk", "out*bound"}, + {"question-mark", "out?bound"}, + {"open-bracket", "out[bound"}, + {"close-bracket", "out]bound"}, + } + + for _, tc := range invalidChars { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := h.Enqueue(ctx, &FailedMessage{ + Source: tc.source, + OriginalData: []byte(`{}`), + ErrorMessage: "err", + }) + require.Error(t, err, "source %q containing disallowed char should be rejected", tc.source) + assert.Contains(t, err.Error(), "disallowed character", + "error for source %q should mention disallowed character", tc.source) + }) + } +} + +// TestScanQueues_NilHandler verifies that ScanQueues on a nil Handler returns +// ErrNilHandler without panicking. 
+func TestScanQueues_NilHandler(t *testing.T) { + t.Parallel() + + var h *Handler + + keys, err := h.ScanQueues(context.Background(), "outbound") + require.ErrorIs(t, err, ErrNilHandler) + assert.Nil(t, keys) +} + +// TestPruneExhaustedMessages_NilHandler verifies that PruneExhaustedMessages on +// a nil Handler returns ErrNilHandler without panicking. +func TestPruneExhaustedMessages_NilHandler(t *testing.T) { + t.Parallel() + + var h *Handler + + pruned, err := h.PruneExhaustedMessages(context.Background(), "outbound", 10) + require.ErrorIs(t, err, ErrNilHandler) + assert.Equal(t, 0, pruned) +} + +// TestPruneExhaustedMessages_ZeroLimit verifies that PruneExhaustedMessages with +// limit=0 (or negative) returns immediately with zero pruned and no error. +// This is a short-circuit guard in the implementation: a zero/negative limit +// means "prune nothing", which is a valid no-op call — not an error condition. +func TestPruneExhaustedMessages_ZeroLimit(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-abc") + + // Enqueue an exhausted message so we can verify it is NOT consumed when limit=0. + require.NoError(t, h.Enqueue(ctx, &FailedMessage{ + Source: "outbound", + OriginalData: []byte(`{}`), + ErrorMessage: "permanent failure", + RetryCount: 3, // = maxRetries + })) + + pruned, err := h.PruneExhaustedMessages(ctx, "outbound", 0) + require.NoError(t, err, "limit=0 should not return an error") + assert.Equal(t, 0, pruned, "limit=0 should prune zero messages") + + // The exhausted message should still be in the queue — we did not touch it. + length, err := h.QueueLength(ctx, "outbound") + require.NoError(t, err) + assert.Equal(t, int64(1), length, "message should remain in queue after limit=0 prune call") + + // Also verify negative limit has the same short-circuit behaviour. 
+ pruned, err = h.PruneExhaustedMessages(ctx, "outbound", -5) + require.NoError(t, err, "negative limit should not return an error") + assert.Equal(t, 0, pruned, "negative limit should prune zero messages") +} + +// TestDequeue_Success_CompleteAssertions extends the basic Dequeue success case +// to assert every FailedMessage field is correctly round-tripped through the +// enqueue → dequeue cycle. +func TestDequeue_Success_CompleteAssertions(t *testing.T) { + t.Parallel() + + h, _ := newTestHandler(t) + ctx := tmcore.ContextWithTenantID(context.Background(), "tenant-xyz") + + original := &FailedMessage{ + Source: "payments", + OriginalData: []byte(`{"amount":42}`), + ErrorMessage: "gateway timeout", + } + + before := time.Now().UTC() + require.NoError(t, h.Enqueue(ctx, original)) + after := time.Now().UTC() + + got, err := h.Dequeue(ctx, "payments") + require.NoError(t, err) + require.NotNil(t, got) + + assert.Equal(t, "payments", got.Source, "Source should round-trip") + assert.Equal(t, "tenant-xyz", got.TenantID, "TenantID should be set from context") + assert.Equal(t, []byte(`{"amount":42}`), got.OriginalData, "OriginalData should round-trip") + assert.Equal(t, "gateway timeout", got.ErrorMessage, "ErrorMessage should round-trip") + assert.Equal(t, 0, got.RetryCount, "RetryCount should be 0 on first enqueue") + assert.Equal(t, h.maxRetries, got.MaxRetries, "MaxRetries should be stamped from handler default") + assert.False(t, got.CreatedAt.IsZero(), "CreatedAt should not be zero after enqueue") + assert.True(t, !got.CreatedAt.Before(before) && !got.CreatedAt.After(after), + "CreatedAt should be within the test window [%s, %s], got %s", before, after, got.CreatedAt) + assert.False(t, got.NextRetryAt.IsZero(), "NextRetryAt should be set for a retryable message") + assert.True(t, got.NextRetryAt.After(got.CreatedAt), + "NextRetryAt should be in the future relative to CreatedAt") +} + +// TestTruncateString verifies the truncateString helper across its boundary 
cases. +func TestTruncateString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + maxLen int + want string + }{ + { + name: "empty string", + input: "", + maxLen: 10, + want: "", + }, + { + name: "short string under limit", + input: "hello", + maxLen: 10, + want: "hello", + }, + { + name: "string at exact limit", + input: "helloworld", + maxLen: 10, + want: "helloworld", + }, + { + name: "string over limit", + input: "helloworld!", + maxLen: 10, + want: "helloworld...", + }, + { + name: "long string gets truncated with ellipsis", + input: "this is a very long error message that must be trimmed", + maxLen: 20, + want: "this is a very long ...", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := truncateString(tc.input, tc.maxLen) + assert.Equal(t, tc.want, got) + + // For over-limit cases, the result should end in "...". + if len(tc.input) > tc.maxLen { + assert.True(t, len(got) == tc.maxLen+3, + "truncated string should be maxLen+3 bytes (maxLen + ellipsis), got len=%d", len(got)) + } + }) + } +} diff --git a/commons/net/http/idempotency/doc.go b/commons/net/http/idempotency/doc.go new file mode 100644 index 00000000..dfd0c50c --- /dev/null +++ b/commons/net/http/idempotency/doc.go @@ -0,0 +1,50 @@ +// Package idempotency provides Fiber middleware for best-effort idempotency +// backed by Redis. +// +// The middleware enforces at-most-once semantics when Redis is available. On +// Redis outages, it fails open to preserve service availability — duplicate +// requests may execute more than once. Callers that require strict at-most-once +// guarantees must pair this middleware with application-level safeguards. +// +// The middleware uses the X-Idempotency request header combined with the tenant ID +// (from tenant-manager context) to form a composite Redis key. When a tenant ID is +// present, keys are scoped per-tenant to prevent cross-tenant collisions. 
When no
+// tenant is in context (e.g., non-tenant-scoped routes), keys are still valid but
+// are shared across the global namespace. Duplicate requests receive the original
+// response with the X-Idempotency-Replayed header set to "true".
+//
+// # Quick start
+//
+//	conn, err := redis.New(ctx, cfg)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	idem := idempotency.New(conn)
+//	app.Post("/orders", idem.Check(), createOrderHandler)
+//
+// # Behavior
+//
+//   - GET/OPTIONS requests pass through (idempotency is irrelevant for reads).
+//   - Absent header: request proceeds normally (idempotency is opt-in).
+//   - Header exceeds MaxKeyLength (default 256): request rejected with 400.
+//   - Duplicate key: cached response replayed with X-Idempotency-Replayed: true.
+//   - Redis failure: request proceeds (fail-open for availability).
+//   - Handler success: response cached; handler failure: key deleted (client can retry).
+//
+// # Redis key namespace convention
+//
+// Keys follow the pattern: <prefix><tenant_id>:<idempotency_key>
+// with a companion response key at <prefix><tenant_id>:<idempotency_key>:response.
+//
+// The default prefix is "idempotency:" and can be overridden via WithKeyPrefix.
+// This namespacing convention is consistent with other lib-commons packages that
+// use Redis (e.g., rate limiting uses "ratelimit:<tenant>:..."). Per-tenant
+// isolation is enforced by embedding the tenant ID into the key rather than
+// using separate Redis databases or key-space notifications, which keeps the
+// implementation topology-agnostic (standalone, sentinel, and cluster all behave
+// identically with this approach).
+//
+// # Nil safety
+//
+// A nil *Middleware returns a pass-through handler from Check().
+package idempotency diff --git a/commons/net/http/idempotency/idempotency.go b/commons/net/http/idempotency/idempotency.go new file mode 100644 index 00000000..b4714e2d --- /dev/null +++ b/commons/net/http/idempotency/idempotency.go @@ -0,0 +1,308 @@ +package idempotency + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + chttp "github.com/LerianStudio/lib-commons/v4/commons/constants" + "github.com/LerianStudio/lib-commons/v4/commons/log" + libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis" + tmcore "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core" + "github.com/gofiber/fiber/v2" + "github.com/redis/go-redis/v9" +) + +const ( + keyStateProcessing = "processing" + keyStateComplete = "complete" +) + +// cachedResponse stores the full HTTP response for idempotent replay. +type cachedResponse struct { + StatusCode int `json:"status_code"` + ContentType string `json:"content_type"` + Body string `json:"body"` +} + +// Option configures the idempotency middleware. +type Option func(*Middleware) + +// Middleware provides at-most-once request semantics using Redis SetNX. +type Middleware struct { + conn *libRedis.Client + logger log.Logger + keyPrefix string + keyTTL time.Duration + maxKeyLength int + maxBodyCache int + redisTimeout time.Duration + onRejected func(c *fiber.Ctx) error +} + +// New creates an idempotency middleware backed by the given Redis client. +// Returns nil if conn is nil (nil-safe: Check() returns pass-through). +func New(conn *libRedis.Client, opts ...Option) *Middleware { + if conn == nil { + return nil + } + + m := &Middleware{ + conn: conn, + logger: log.NewNop(), + keyPrefix: "idempotency:", + keyTTL: 7 * 24 * time.Hour, + maxKeyLength: 256, + maxBodyCache: 1 << 20, // 1 MB default + redisTimeout: 500 * time.Millisecond, + } + + for _, opt := range opts { + if opt != nil { + opt(m) + } + } + + return m +} + +// WithLogger sets a structured logger. 
+func WithLogger(l log.Logger) Option { + return func(m *Middleware) { + if l != nil { + m.logger = l + } + } +} + +// WithKeyPrefix sets the Redis key prefix (default: "idempotency:"). +func WithKeyPrefix(prefix string) Option { + return func(m *Middleware) { + if prefix != "" { + m.keyPrefix = prefix + } + } +} + +// WithKeyTTL sets how long idempotency keys are retained (default: 7 days). +func WithKeyTTL(ttl time.Duration) Option { + return func(m *Middleware) { + if ttl > 0 { + m.keyTTL = ttl + } + } +} + +// WithMaxKeyLength sets the maximum allowed idempotency key length (default: 256). +func WithMaxKeyLength(n int) Option { + return func(m *Middleware) { + if n > 0 { + m.maxKeyLength = n + } + } +} + +// WithRedisTimeout sets the timeout for Redis operations (default: 500ms). +func WithRedisTimeout(d time.Duration) Option { + return func(m *Middleware) { + if d > 0 { + m.redisTimeout = d + } + } +} + +// WithRejectedHandler sets a custom handler for requests with oversized keys. +// By default, a generic 400 JSON response is returned. +func WithRejectedHandler(fn func(c *fiber.Ctx) error) Option { + return func(m *Middleware) { + m.onRejected = fn + } +} + +// WithMaxBodyCache sets the maximum response body size (in bytes) that will be +// cached in Redis for idempotent replay (default: 1 MB). Responses larger than +// this limit are not cached; duplicate requests will receive a generic +// "already processed" response instead. +// Values <= 0 are ignored. +func WithMaxBodyCache(n int) Option { + return func(m *Middleware) { + if n > 0 { + m.maxBodyCache = n + } + } +} + +// Check returns a Fiber middleware that enforces idempotency on mutating requests. +// If the Middleware is nil, a pass-through handler is returned. 
+func (m *Middleware) Check() fiber.Handler { + if m == nil { + return func(c *fiber.Ctx) error { + return c.Next() + } + } + + return m.handle +} + +func (m *Middleware) handle(c *fiber.Ctx) error { + // Idempotency only applies to mutating methods. + if c.Method() == fiber.MethodGet || c.Method() == fiber.MethodOptions || c.Method() == fiber.MethodHead { + return c.Next() + } + + idempotencyKey := c.Get(chttp.IdempotencyKey) + if idempotencyKey == "" { + return c.Next() + } + + if len(idempotencyKey) > m.maxKeyLength { + if m.onRejected != nil { + return m.onRejected(c) + } + + return c.Status(http.StatusBadRequest).JSON(fiber.Map{ + "code": "VALIDATION_ERROR", + "message": fmt.Sprintf("%s must not exceed %d characters", chttp.IdempotencyKey, m.maxKeyLength), + }) + } + + // Build a tenant-scoped Redis key for per-tenant isolation. + tenantID := tmcore.GetTenantIDContext(c.UserContext()) + key := fmt.Sprintf("%s%s:%s", m.keyPrefix, tenantID, idempotencyKey) + + ctx, cancel := context.WithTimeout(c.UserContext(), m.redisTimeout) + defer cancel() + + client, err := m.conn.GetClient(ctx) + if err != nil { + // Redis unavailable — fail-open to preserve availability. + m.logger.Log(ctx, log.LevelWarn, "idempotency: redis unavailable, failing open", log.Err(err)) + return c.Next() + } + + // SetNX atomically checks and sets — returns true only if the key was newly created. + set, setnxErr := client.SetNX(ctx, key, keyStateProcessing, m.keyTTL).Result() + if setnxErr != nil { + m.logger.Log(ctx, log.LevelWarn, "idempotency: setnx failed, failing open", log.Err(setnxErr)) + return c.Next() + } + + responseKey := key + ":response" + + if !set { + return m.handleDuplicate(ctx, c, client, key, responseKey) + } + + // Proceed with the actual handler. + handlerErr := c.Next() + + // Create fresh context for post-handler Redis bookkeeping. + // The pre-handler ctx may have expired during handler execution. 
+ postCtx, postCancel := context.WithTimeout(context.WithoutCancel(c.UserContext()), m.redisTimeout) + defer postCancel() + + m.saveResult(postCtx, c, client, key, responseKey, handlerErr) + + return handlerErr +} + +// handleDuplicate processes a duplicate request (one whose idempotency key already exists +// in Redis). It attempts to replay the cached response when available, falls back to a +// conflict response when the original request is still in flight, or returns a generic +// "already processed" response when the key is complete but the body was not cached. +func (m *Middleware) handleDuplicate( + ctx context.Context, + c *fiber.Ctx, + client redis.UniversalClient, + key, responseKey string, +) error { + // Read the current key value to distinguish in-flight from completed. + keyValue, _ := client.Get(ctx, key).Result() + + // Try to replay the cached response (true idempotency). + cached, cacheErr := client.Get(ctx, responseKey).Result() + if cacheErr == nil && cached != "" { + var resp cachedResponse + if unmarshalErr := json.Unmarshal([]byte(cached), &resp); unmarshalErr == nil { + c.Set(chttp.IdempotencyReplayed, "true") + c.Set("Content-Type", resp.ContentType) + + return c.Status(resp.StatusCode).SendString(resp.Body) + } + } + + // No cached response available — differentiate by key state. + c.Set(chttp.IdempotencyReplayed, "true") + + if keyValue == keyStateProcessing { + // Request is still in flight — tell the client to retry later. + return c.Status(http.StatusConflict).JSON(fiber.Map{ + "code": "IDEMPOTENCY_CONFLICT", + "detail": "a request with this idempotency key is currently being processed", + }) + } + + // Key is "complete" but the response body was not cached + // (e.g., body exceeded maxBodyCache limit). 
+ return c.Status(http.StatusOK).JSON(fiber.Map{ + "code": "IDEMPOTENT", + "detail": "request already processed", + }) +} + +// saveResult performs post-handler Redis bookkeeping: on success it caches the response +// body and marks the key as complete in a single round-trip via a Redis pipeline; on +// handler error it deletes both keys so the client can retry with the same idempotency key. +func (m *Middleware) saveResult( + ctx context.Context, + c *fiber.Ctx, + client redis.UniversalClient, + key, responseKey string, + handlerErr error, +) { + if handlerErr == nil { + body := c.Response().Body() + + pipe := client.Pipeline() + + if len(body) <= m.maxBodyCache { + resp := cachedResponse{ + StatusCode: c.Response().StatusCode(), + ContentType: string(c.Response().Header.ContentType()), + Body: string(body), + } + + if data, marshalErr := json.Marshal(resp); marshalErr == nil { + pipe.Set(ctx, responseKey, string(data), m.keyTTL) + } + } else { + m.logger.Log(ctx, log.LevelWarn, + "idempotency: response body exceeds maxBodyCache, skipping cache", + log.Int("body_size", len(body)), + log.Int("max_body_cache", m.maxBodyCache), + ) + } + + pipe.Set(ctx, key, keyStateComplete, m.keyTTL) + + if _, pipeErr := pipe.Exec(ctx); pipeErr != nil { + m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to atomically cache response and mark complete", + log.Err(pipeErr), + ) + } + } else { + pipe := client.Pipeline() + pipe.Del(ctx, key) + pipe.Del(ctx, responseKey) + + if _, pipeErr := pipe.Exec(ctx); pipeErr != nil { + m.logger.Log(ctx, log.LevelWarn, + "idempotency: failed to delete keys after handler error", + log.Err(pipeErr), + ) + } + } +} diff --git a/commons/net/http/idempotency/idempotency_test.go b/commons/net/http/idempotency/idempotency_test.go new file mode 100644 index 00000000..623863ed --- /dev/null +++ b/commons/net/http/idempotency/idempotency_test.go @@ -0,0 +1,823 @@ +//go:build unit + +package idempotency + +import ( + "errors" + "io" + "net" + 
"net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + chttp "github.com/LerianStudio/lib-commons/v4/commons/constants" + libLog "github.com/LerianStudio/lib-commons/v4/commons/log" + libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis" + tmcore "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core" + "github.com/alicebob/miniredis/v2" + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// newRedisClient creates a *libRedis.Client backed by a miniredis instance. +// The connection is closed automatically when the test finishes. +func newRedisClient(t *testing.T, mr *miniredis.Miniredis) *libRedis.Client { + t.Helper() + + conn, err := libRedis.New(t.Context(), libRedis.Config{ + Topology: libRedis.Topology{ + Standalone: &libRedis.StandaloneTopology{Address: mr.Addr()}, + }, + Logger: &libLog.NopLogger{}, + }) + require.NoError(t, err) + + t.Cleanup(func() { + if err := conn.Close(); err != nil { + t.Logf("redis close: %v", err) + } + }) + + return conn +} + +// newPostApp builds a Fiber app that routes POST /test through the given +// middleware, then calls a handler that writes 201 + JSON body. +// An optional pre-middleware is called before the idempotency middleware +// to let tests inject tenant context. +func newPostApp(mw fiber.Handler, preMiddleware ...fiber.Handler) *fiber.App { + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + for _, pm := range preMiddleware { + app.Use(pm) + } + + app.Use(mw) + + app.Post("/test", func(c *fiber.Ctx) error { + return c.Status(fiber.StatusCreated).JSON(fiber.Map{"status": "created"}) + }) + + // Also register GET and OPTIONS for pass-through tests. 
+ app.Get("/test", func(c *fiber.Ctx) error { + return c.SendString("ok-get") + }) + + app.Options("/test", func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusNoContent) + }) + + return app +} + +// tenantMiddleware returns a Fiber handler that injects tenantID into the +// request's user context via tmcore.ContextWithTenantID, mimicking real +// tenant-extraction middleware. +func tenantMiddleware(tenantID string) fiber.Handler { + return func(c *fiber.Ctx) error { + ctx := tmcore.ContextWithTenantID(c.UserContext(), tenantID) + c.SetUserContext(ctx) + + return c.Next() + } +} + +// doPost sends a POST /test with the given idempotency key header. +func doPost(t *testing.T, app *fiber.App, idempotencyKey string) *http.Response { + t.Helper() + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + if idempotencyKey != "" { + req.Header.Set(chttp.IdempotencyKey, idempotencyKey) + } + + resp, err := app.Test(req, -1) + require.NoError(t, err) + + return resp +} + +// readBody reads and returns the full response body, closing it. +func readBody(t *testing.T, resp *http.Response) string { + t.Helper() + + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + return string(b) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestNew_NilConn(t *testing.T) { + t.Parallel() + + m := New(nil) + assert.Nil(t, m, "New(nil) must return nil middleware") +} + +func TestCheck_NilMiddleware(t *testing.T) { + t.Parallel() + + var m *Middleware // nil + + handler := m.Check() + require.NotNil(t, handler, "Check() on nil receiver must return a handler") + + // The handler must be a pass-through. 
+ app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app.Use(handler) + app.Post("/test", func(c *fiber.Ctx) error { + return c.Status(fiber.StatusCreated).JSON(fiber.Map{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + req.Header.Set(chttp.IdempotencyKey, "some-key") + + resp, err := app.Test(req, -1) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusCreated, resp.StatusCode, + "nil middleware must pass through to the actual handler") +} + +func TestCheck_GET_PassesThrough(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + app := newPostApp(m.Check()) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set(chttp.IdempotencyKey, "should-be-ignored") + + resp, err := app.Test(req, -1) + require.NoError(t, err) + + body := readBody(t, resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "ok-get", body) +} + +func TestCheck_OPTIONS_PassesThrough(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + app := newPostApp(m.Check()) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + req.Header.Set(chttp.IdempotencyKey, "should-be-ignored") + + resp, err := app.Test(req, -1) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusNoContent, resp.StatusCode) +} + +func TestCheck_NoHeader_PassesThrough(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + app := newPostApp(m.Check()) + + // POST without idempotency header — proceeds normally. 
+ resp := doPost(t, app, "") + defer resp.Body.Close() + + assert.Equal(t, http.StatusCreated, resp.StatusCode) +} + +func TestCheck_KeyTooLong_Rejected(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn, WithMaxKeyLength(10)) + + app := newPostApp(m.Check()) + + longKey := strings.Repeat("x", 11) + resp := doPost(t, app, longKey) + body := readBody(t, resp) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.Contains(t, body, "VALIDATION_ERROR") + assert.Contains(t, body, chttp.IdempotencyKey) +} + +func TestCheck_KeyTooLong_CustomHandler(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + + m := New(conn, + WithMaxKeyLength(5), + WithRejectedHandler(func(c *fiber.Ctx) error { + return c.Status(http.StatusUnprocessableEntity).JSON(fiber.Map{ + "custom": "rejected", + }) + }), + ) + + app := newPostApp(m.Check()) + + longKey := strings.Repeat("k", 6) + resp := doPost(t, app, longKey) + body := readBody(t, resp) + + assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) + assert.Contains(t, body, "rejected") +} + +func TestCheck_FirstRequest_Proceeds(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + app := newPostApp(m.Check(), tenantMiddleware("tenant-1")) + + resp := doPost(t, app, "unique-key-1") + body := readBody(t, resp) + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + assert.Contains(t, body, "created") + + // Verify the response was cached in Redis. + keys := mr.Keys() + // Expect two keys: the lock key and the :response key. 
+ assert.GreaterOrEqual(t, len(keys), 2, + "expected lock key + response key in Redis, got: %v", keys) +} + +func TestCheck_DuplicateRequest_ReplaysResponse(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + app := newPostApp(m.Check(), tenantMiddleware("tenant-dup")) + + // First request — proceeds normally. + resp1 := doPost(t, app, "dup-key") + body1 := readBody(t, resp1) + assert.Equal(t, http.StatusCreated, resp1.StatusCode) + assert.Contains(t, body1, "created") + + // Second request — same key — must replay. + resp2 := doPost(t, app, "dup-key") + body2 := readBody(t, resp2) + + assert.Equal(t, http.StatusCreated, resp2.StatusCode, + "replayed response must have the original status code") + assert.Contains(t, body2, "created", + "replayed response must have the original body") + assert.Equal(t, "true", resp2.Header.Get(chttp.IdempotencyReplayed), + "replayed response must set X-Idempotency-Replayed: true") +} + +func TestCheck_DuplicateRequest_StillProcessing(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + // Simulate a first request that is "still processing" by setting the lock + // key directly in Redis without a :response companion. + tenantID := "tenant-proc" + idempotencyKey := "processing-key" + lockKey := "idempotency:" + tenantID + ":" + idempotencyKey + + require.NoError(t, mr.Set(lockKey, "processing")) + mr.SetTTL(lockKey, 7*24*time.Hour) + + app := newPostApp(m.Check(), tenantMiddleware(tenantID)) + + resp := doPost(t, app, idempotencyKey) + body := readBody(t, resp) + + // The current production code (idempotency.go) returns 409 Conflict when the + // key is in "processing" state — the request is still in-flight. The generic + // 200 IDEMPOTENT response is only returned when the key is "complete" but the + // response body was not cached (e.g., body exceeded maxBodyCache). 
+ assert.Equal(t, http.StatusConflict, resp.StatusCode) + assert.Contains(t, body, "IDEMPOTENCY_CONFLICT") + assert.Equal(t, "true", resp.Header.Get(chttp.IdempotencyReplayed)) +} + +func TestCheck_FailedRequest_KeyDeleted(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + handlerErr := errors.New("handler boom") + + // Build a custom app whose handler returns an error. + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app.Use(tenantMiddleware("tenant-fail")) + app.Use(m.Check()) + app.Post("/test", func(_ *fiber.Ctx) error { + return handlerErr + }) + + resp := doPost(t, app, "fail-key") + defer resp.Body.Close() + + // Fiber translates an unhandled error to 500. + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + + // The lock key must have been deleted so the client can retry. + keys := mr.Keys() + assert.Empty(t, keys, "all keys must be deleted after handler failure, got: %v", keys) +} + +func TestCheck_TenantIsolation(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + sharedKey := "same-idem-key" + + // Tenant A — first request proceeds. + appA := newPostApp(m.Check(), tenantMiddleware("tenant-A")) + respA := doPost(t, appA, sharedKey) + bodyA := readBody(t, respA) + assert.Equal(t, http.StatusCreated, respA.StatusCode) + assert.Contains(t, bodyA, "created") + + // Tenant B — same idempotency key, different tenant — must also proceed. + appB := newPostApp(m.Check(), tenantMiddleware("tenant-B")) + respB := doPost(t, appB, sharedKey) + bodyB := readBody(t, respB) + assert.Equal(t, http.StatusCreated, respB.StatusCode, + "same key for a different tenant must proceed independently") + assert.Contains(t, bodyB, "created") + + // Tenant A — duplicate of same key — must replay. 
+ respA2 := doPost(t, appA, sharedKey) + assert.Equal(t, "true", respA2.Header.Get(chttp.IdempotencyReplayed), + "same key + same tenant must replay") + respA2.Body.Close() +} + +func TestOptions_Defaults(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + require.NotNil(t, m) + assert.Equal(t, "idempotency:", m.keyPrefix) + assert.Equal(t, 7*24*time.Hour, m.keyTTL) + assert.Equal(t, 256, m.maxKeyLength) + assert.Equal(t, 500*time.Millisecond, m.redisTimeout) + assert.Nil(t, m.onRejected, "default rejected handler should be nil (use built-in)") +} + +// --------------------------------------------------------------------------- +// Option application tests +// --------------------------------------------------------------------------- + +func TestOptions_Custom(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts []Option + checkFn func(t *testing.T, m *Middleware) + }{ + { + name: "WithKeyPrefix", + opts: []Option{WithKeyPrefix("custom:")}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, "custom:", m.keyPrefix) + }, + }, + { + name: "WithKeyPrefix empty ignored", + opts: []Option{WithKeyPrefix("")}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, "idempotency:", m.keyPrefix, "empty prefix must be ignored") + }, + }, + { + name: "WithKeyTTL", + opts: []Option{WithKeyTTL(1 * time.Hour)}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, 1*time.Hour, m.keyTTL) + }, + }, + { + name: "WithKeyTTL zero ignored", + opts: []Option{WithKeyTTL(0)}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, 7*24*time.Hour, m.keyTTL, "zero TTL must be ignored") + }, + }, + { + name: "WithMaxKeyLength", + opts: []Option{WithMaxKeyLength(64)}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, 64, m.maxKeyLength) + }, + }, + { + name: "WithMaxKeyLength zero 
ignored", + opts: []Option{WithMaxKeyLength(0)}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, 256, m.maxKeyLength, "zero maxKeyLength must be ignored") + }, + }, + { + name: "WithRedisTimeout", + opts: []Option{WithRedisTimeout(2 * time.Second)}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, 2*time.Second, m.redisTimeout) + }, + }, + { + name: "WithRedisTimeout zero ignored", + opts: []Option{WithRedisTimeout(0)}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.Equal(t, 500*time.Millisecond, m.redisTimeout, "zero timeout must be ignored") + }, + }, + { + name: "WithLogger nil ignored", + opts: []Option{WithLogger(nil)}, + checkFn: func(t *testing.T, m *Middleware) { + t.Helper() + assert.NotNil(t, m.logger, "nil logger must keep the default nop logger") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn, tt.opts...) + require.NotNil(t, m) + + tt.checkFn(t, m) + }) + } +} + +// --------------------------------------------------------------------------- +// Redis failure — fail-open behavior +// --------------------------------------------------------------------------- + +func TestCheck_RedisDown_FailsOpen(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + app := newPostApp(m.Check(), tenantMiddleware("tenant-failopen")) + + // Kill Redis before the request. + mr.Close() + + resp := doPost(t, app, "key-while-redis-down") + defer resp.Body.Close() + + // fail-open: handler proceeds despite Redis being unreachable. 
+ assert.Equal(t, http.StatusCreated, resp.StatusCode, + "must fail open when Redis is unavailable") +} + +// --------------------------------------------------------------------------- +// Verify that the response key uses correct prefix +// --------------------------------------------------------------------------- + +func TestCheck_RedisKeyFormat(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn, WithKeyPrefix("idem:")) + + app := newPostApp(m.Check(), tenantMiddleware("t1")) + + resp := doPost(t, app, "my-key") + resp.Body.Close() + + keys := mr.Keys() + require.Len(t, keys, 2, "expected lock + response keys") + + // Verify the key format: prefix + tenantID + idempotency key. + foundLock := false + foundResp := false + + for _, k := range keys { + if k == "idem:t1:my-key" { + foundLock = true + } + + if k == "idem:t1:my-key:response" { + foundResp = true + } + } + + assert.True(t, foundLock, "lock key must match expected format, got: %v", keys) + assert.True(t, foundResp, "response key must match expected format, got: %v", keys) +} + +// --------------------------------------------------------------------------- +// Concurrent same-key requests +// --------------------------------------------------------------------------- + +// TestCheck_ConcurrentSameKey launches 10 goroutines all POST-ing with the +// same idempotency key simultaneously. Exactly 1 should reach the upstream +// handler (get 201), while the rest receive either the cached 201 replay +// (Idempotency-Replayed: true) or a 409 IDEMPOTENCY_CONFLICT (in-flight). +func TestCheck_ConcurrentSameKey(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + // Build a Fiber app and start it on a real listener so many goroutines can + // hit it concurrently — app.Test() serialises internally. 
+ app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app.Use(tenantMiddleware("tenant-conc")) + app.Use(m.Check()) + app.Post("/test", func(c *fiber.Ctx) error { + return c.Status(fiber.StatusCreated).JSON(fiber.Map{"status": "created"}) + }) + + ln, listenErr := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, listenErr) + + go func() { _ = app.Listener(ln) }() + t.Cleanup(func() { _ = app.Shutdown() }) + + addr := ln.Addr().String() + + const goroutines = 10 + + type result struct { + status int + replayed string + body string + } + + results := make([]result, goroutines) + + var wg sync.WaitGroup + + for i := range goroutines { + wg.Add(1) + + go func(idx int) { + defer wg.Done() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://"+addr+"/test", nil) + if err != nil { + return + } + + req.Header.Set(chttp.IdempotencyKey, "shared-concurrent-key") + + resp, doErr := http.DefaultClient.Do(req) + if doErr != nil { + return + } + + defer resp.Body.Close() + + b, _ := io.ReadAll(resp.Body) + results[idx] = result{ + status: resp.StatusCode, + replayed: resp.Header.Get(chttp.IdempotencyReplayed), + body: string(b), + } + }(i) + } + + wg.Wait() + + // Count how many got the original 201 without the replayed header. + // The rest must be 201 replays (Idempotency-Replayed: true) or 409 in-flight. + originals := 0 + + for _, r := range results { + if r.status == http.StatusCreated && r.replayed == "" { + originals++ + } else { + // Must be either a replay (201+replayed header) or in-flight (409). 
+ ok := (r.status == http.StatusCreated && r.replayed == "true") || + r.status == http.StatusConflict + assert.True(t, ok, + "expected 201-replay or 409, got status=%d replayed=%q body=%s", + r.status, r.replayed, r.body) + } + } + + assert.Equal(t, 1, originals, + "exactly one goroutine must receive the original 201 from the handler") +} + +// --------------------------------------------------------------------------- +// Max body cache limit — oversized response falls back to IDEMPOTENT reply +// --------------------------------------------------------------------------- + +// TestCheck_WithMaxBodyCache verifies that when a response body exceeds the +// configured maxBodyCache, the first request still succeeds (201) but a +// duplicate request receives the generic "IDEMPOTENT" 200 response (the body +// was too large to cache). +func TestCheck_WithMaxBodyCache(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + + // Allow only 10 bytes of body cache — our handler returns ~35 bytes. + m := New(conn, WithMaxBodyCache(10)) + + // Handler returns a body clearly larger than 10 bytes. + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app.Use(tenantMiddleware("tenant-maxcache")) + app.Use(m.Check()) + app.Post("/test", func(c *fiber.Ctx) error { + return c.Status(fiber.StatusCreated).JSON(fiber.Map{"result": "ok", "extra": "padding-to-exceed-limit"}) + }) + + // First request — proceeds to the handler. + req1 := httptest.NewRequest(http.MethodPost, "/test", nil) + req1.Header.Set(chttp.IdempotencyKey, "big-body-key") + + resp1, err := app.Test(req1, -1) + require.NoError(t, err) + + defer resp1.Body.Close() + + assert.Equal(t, http.StatusCreated, resp1.StatusCode, "first request must reach the handler") + + // Second request — same key. Body was not cached (too large), so must get + // the generic IDEMPOTENT fallback, not the original 201 body. 
+ req2 := httptest.NewRequest(http.MethodPost, "/test", nil) + req2.Header.Set(chttp.IdempotencyKey, "big-body-key") + + resp2, err := app.Test(req2, -1) + require.NoError(t, err) + + body2 := readBody(t, resp2) + + assert.Equal(t, http.StatusOK, resp2.StatusCode, + "duplicate request with uncached body must get the generic 200 IDEMPOTENT response") + assert.Contains(t, body2, "IDEMPOTENT") + assert.Equal(t, "true", resp2.Header.Get(chttp.IdempotencyReplayed)) +} + +// --------------------------------------------------------------------------- +// In-flight detection — 409 Conflict +// --------------------------------------------------------------------------- + +// TestCheck_InFlight_Returns409 verifies that while a request is being processed +// (key is in "processing" state with no cached response), a duplicate request +// receives 409 IDEMPOTENCY_CONFLICT. Then the first request completes normally. +func TestCheck_InFlight_Returns409(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + // Pre-set the Redis key to "processing" without a response key — this + // simulates a first request that is currently in-flight. + tenantID := "tenant-inflight" + idempotencyKey := "inflight-key" + lockKey := "idempotency:" + tenantID + ":" + idempotencyKey + + require.NoError(t, mr.Set(lockKey, keyStateProcessing)) + mr.SetTTL(lockKey, 7*24*time.Hour) + + // Build an app and send a duplicate — no response key exists. 
+ app := newPostApp(m.Check(), tenantMiddleware(tenantID)) + + resp := doPost(t, app, idempotencyKey) + body := readBody(t, resp) + + assert.Equal(t, http.StatusConflict, resp.StatusCode, + "duplicate of an in-flight request must return 409 Conflict") + assert.Contains(t, body, "IDEMPOTENCY_CONFLICT", + "response body must contain IDEMPOTENCY_CONFLICT code") + assert.Equal(t, "true", resp.Header.Get(chttp.IdempotencyReplayed)) +} + +// --------------------------------------------------------------------------- +// HEAD request passthrough +// --------------------------------------------------------------------------- + +// TestCheck_HeadRequest_PassesThrough verifies that HEAD requests bypass the +// idempotency middleware even when an Idempotency-Key header is present. +func TestCheck_HeadRequest_PassesThrough(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + m := New(conn) + + var handlerCalled atomic.Bool + + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app.Use(m.Check()) + app.Head("/test", func(c *fiber.Ctx) error { + handlerCalled.Store(true) + return c.SendStatus(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodHead, "/test", nil) + req.Header.Set(chttp.IdempotencyKey, "head-key-should-be-ignored") + + resp, err := app.Test(req, -1) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, + "HEAD request must reach the handler directly") + assert.True(t, handlerCalled.Load(), + "handler must be called for HEAD requests (idempotency middleware is bypassed)") + + // The idempotency middleware must not have set the replayed header. + assert.Empty(t, resp.Header.Get(chttp.IdempotencyReplayed), + "HEAD pass-through must not set the Idempotency-Replayed header") + + // Verify that no Redis keys were written — HEAD bypasses the whole flow. 
+ assert.Empty(t, mr.Keys(), "HEAD pass-through must not write any Redis keys") +} + +// --------------------------------------------------------------------------- +// Negative option values — defaults must be preserved +// --------------------------------------------------------------------------- + +// TestOptions_NegativeValues verifies that negative values passed to option +// constructors are treated as invalid and the configured defaults are preserved. +func TestOptions_NegativeValues(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + conn := newRedisClient(t, mr) + + m := New(conn, + WithMaxKeyLength(-1), + WithKeyTTL(-1*time.Hour), + WithRedisTimeout(-1*time.Second), + ) + + require.NotNil(t, m) + + assert.Equal(t, 256, m.maxKeyLength, + "negative maxKeyLength must be ignored; default (256) must be preserved") + assert.Equal(t, 7*24*time.Hour, m.keyTTL, + "negative keyTTL must be ignored; default (7 days) must be preserved") + assert.Equal(t, 500*time.Millisecond, m.redisTimeout, + "negative redisTimeout must be ignored; default (500ms) must be preserved") +} diff --git a/commons/systemplane/adapters/http/fiber/dto.go b/commons/systemplane/adapters/http/fiber/dto.go index 2df53020..2fdbc885 100644 --- a/commons/systemplane/adapters/http/fiber/dto.go +++ b/commons/systemplane/adapters/http/fiber/dto.go @@ -52,6 +52,7 @@ type SchemaResponse struct { // SchemaEntryDTO represents a single key's metadata in the schema response. 
type SchemaEntryDTO struct { Key string `json:"key"` + EnvVar string `json:"envVar,omitempty"` Kind string `json:"kind"` AllowedScopes []string `json:"allowedScopes"` ValueType string `json:"valueType"` @@ -59,6 +60,7 @@ type SchemaEntryDTO struct { MutableAtRuntime bool `json:"mutableAtRuntime"` ApplyBehavior string `json:"applyBehavior"` Secret bool `json:"secret"` + RedactPolicy string `json:"redactPolicy"` Description string `json:"description"` Group string `json:"group"` } @@ -108,7 +110,6 @@ type ReloadResponse struct { // Conversion helpers. const ( - redactedPlaceholder = "********" defaultHistoryLimit = 50 maxHistoryLimit = 100 revisionQuoteChar = "\"" @@ -197,6 +198,7 @@ func toSchemaResponse(entries []service.SchemaEntry) SchemaResponse { dtos[i] = SchemaEntryDTO{ Key: entry.Key, + EnvVar: entry.EnvVar, Kind: string(entry.Kind), AllowedScopes: scopes, ValueType: string(entry.ValueType), @@ -204,6 +206,7 @@ func toSchemaResponse(entries []service.SchemaEntry) SchemaResponse { MutableAtRuntime: entry.MutableAtRuntime, ApplyBehavior: string(entry.ApplyBehavior), Secret: entry.Secret, + RedactPolicy: string(entry.RedactPolicy), Description: entry.Description, Group: entry.Group, } diff --git a/commons/systemplane/adapters/http/fiber/dto_test.go b/commons/systemplane/adapters/http/fiber/dto_test.go index c04f608b..094f3260 100644 --- a/commons/systemplane/adapters/http/fiber/dto_test.go +++ b/commons/systemplane/adapters/http/fiber/dto_test.go @@ -116,7 +116,7 @@ func TestToEffectiveValueDTO_Redacted(t *testing.T) { // layer are already masked. The DTO preserves them without re-masking. 
ev := domain.EffectiveValue{ Key: "db.password", - Value: redactedPlaceholder, + Value: "********", Default: "default-secret", Source: "env", Redacted: true, @@ -125,7 +125,7 @@ func TestToEffectiveValueDTO_Redacted(t *testing.T) { dto := toEffectiveValueDTO(ev) assert.Equal(t, "db.password", dto.Key) - assert.Equal(t, redactedPlaceholder, dto.Value) + assert.Equal(t, "********", dto.Value) assert.Equal(t, "default-secret", dto.Default) assert.True(t, dto.Redacted) } @@ -168,6 +168,7 @@ func TestToSchemaResponse(t *testing.T) { entries := []service.SchemaEntry{ { Key: "test.key", + EnvVar: "TEST_KEY", Kind: domain.KindConfig, AllowedScopes: []domain.Scope{domain.ScopeGlobal, domain.ScopeTenant}, ValueType: domain.ValueTypeString, @@ -175,6 +176,7 @@ func TestToSchemaResponse(t *testing.T) { MutableAtRuntime: true, ApplyBehavior: domain.ApplyLiveRead, Secret: true, + RedactPolicy: domain.RedactFull, Description: "A test key", Group: "test", }, @@ -184,16 +186,33 @@ func TestToSchemaResponse(t *testing.T) { require.Len(t, resp.Keys, 1) assert.Equal(t, "test.key", resp.Keys[0].Key) + assert.Equal(t, "TEST_KEY", resp.Keys[0].EnvVar) assert.Equal(t, "config", resp.Keys[0].Kind) assert.Equal(t, []string{"global", "tenant"}, resp.Keys[0].AllowedScopes) assert.Equal(t, "string", resp.Keys[0].ValueType) assert.True(t, resp.Keys[0].MutableAtRuntime) assert.Equal(t, "live-read", resp.Keys[0].ApplyBehavior) assert.True(t, resp.Keys[0].Secret) + assert.Equal(t, "full", resp.Keys[0].RedactPolicy) assert.Equal(t, "A test key", resp.Keys[0].Description) assert.Equal(t, "test", resp.Keys[0].Group) } +func TestToSchemaResponse_DefaultRedactPolicyIsExplicit(t *testing.T) { + t.Parallel() + + resp := toSchemaResponse([]service.SchemaEntry{{ + Key: "test.key", + Kind: domain.KindConfig, + AllowedScopes: []domain.Scope{domain.ScopeGlobal}, + ValueType: domain.ValueTypeString, + RedactPolicy: domain.RedactNone, + }}) + + require.Len(t, resp.Keys, 1) + assert.Equal(t, "none", 
resp.Keys[0].RedactPolicy) +} + func TestParseHistoryFilter_Defaults(t *testing.T) { t.Parallel() diff --git a/commons/systemplane/bootstrap/apply_keydefs_test.go b/commons/systemplane/bootstrap/apply_keydefs_test.go new file mode 100644 index 00000000..0a5cd810 --- /dev/null +++ b/commons/systemplane/bootstrap/apply_keydefs_test.go @@ -0,0 +1,245 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. + +package bootstrap + +import ( + "testing" + + "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyKeyDefs_SetsApplyBehaviors(t *testing.T) { + t.Parallel() + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{DSN: "postgres://localhost/db"}, + } + + defs := []domain.KeyDef{ + {Key: "app.log_level", ApplyBehavior: domain.ApplyLiveRead}, + {Key: "postgres.max_open_conns", ApplyBehavior: domain.ApplyLiveRead}, + {Key: "bacen_spi.timeout_sec", ApplyBehavior: domain.ApplyBundleRebuild}, + } + + cfg.ApplyKeyDefs(defs) + + require.Len(t, cfg.ApplyBehaviors, 3) + assert.Equal(t, domain.ApplyLiveRead, cfg.ApplyBehaviors["app.log_level"]) + assert.Equal(t, domain.ApplyBundleRebuild, cfg.ApplyBehaviors["bacen_spi.timeout_sec"]) + + // Postgres sub-config gets the same map. 
+ require.NotNil(t, cfg.Postgres.ApplyBehaviors) + assert.Equal(t, domain.ApplyLiveRead, cfg.Postgres.ApplyBehaviors["postgres.max_open_conns"]) +} + +func TestApplyKeyDefs_ConfiguresSecretsWhenMasterKeySet(t *testing.T) { + t.Setenv(EnvSecretMasterKey, "test-master-key-32-chars-long!!!") + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{DSN: "postgres://localhost/db"}, + } + + defs := []domain.KeyDef{ + {Key: "auth.client_secret", ApplyBehavior: domain.ApplyBootstrapOnly, Secret: true}, + {Key: "app.log_level", ApplyBehavior: domain.ApplyLiveRead}, + } + + cfg.ApplyKeyDefs(defs) + + require.NotNil(t, cfg.Secrets) + assert.Equal(t, "test-master-key-32-chars-long!!!", cfg.Secrets.MasterKey) + assert.Equal(t, []string{"auth.client_secret"}, cfg.Secrets.SecretKeys) +} + +func TestApplyKeyDefs_NoSecretsWithoutMasterKeyOrSecretKeys(t *testing.T) { + t.Parallel() + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{DSN: "postgres://localhost/db"}, + } + + defs := []domain.KeyDef{ + {Key: "app.log_level", ApplyBehavior: domain.ApplyLiveRead}, + } + + cfg.ApplyKeyDefs(defs) + + assert.Nil(t, cfg.Secrets) +} + +func TestApplyKeyDefs_NilConfig(t *testing.T) { + t.Parallel() + + // Must not panic — verified explicitly. + var cfg *BootstrapConfig + assert.NotPanics(t, func() { + cfg.ApplyKeyDefs([]domain.KeyDef{{Key: "k", ApplyBehavior: domain.ApplyLiveRead}}) + }) +} + +func TestApplyKeyDefs_EmptySlice(t *testing.T) { + t.Parallel() + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{DSN: "postgres://localhost/db"}, + } + + cfg.ApplyKeyDefs([]domain.KeyDef{}) + + // ApplyBehaviors should be an initialized (non-nil) empty map, not nil. + require.NotNil(t, cfg.ApplyBehaviors) + assert.Empty(t, cfg.ApplyBehaviors) + // No secret keys → Secrets must remain nil. 
+ assert.Nil(t, cfg.Secrets) +} + +func TestApplyKeyDefs_SecretKeysButNoMasterKey(t *testing.T) { + // Cannot use t.Parallel() because t.Setenv requires a sequential test. + + // Ensure the env var is NOT set so we exercise the "no master key" branch. + t.Setenv(EnvSecretMasterKey, "") + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{DSN: "postgres://localhost/db"}, + } + + defs := []domain.KeyDef{ + {Key: "auth.client_secret", ApplyBehavior: domain.ApplyBootstrapOnly, Secret: true}, + } + + cfg.ApplyKeyDefs(defs) + + // Even though there are secret keys, no master key → Secrets stays nil. + assert.Nil(t, cfg.Secrets) + // But ApplyBehaviors is still populated. + require.Len(t, cfg.ApplyBehaviors, 1) + assert.Equal(t, domain.ApplyBootstrapOnly, cfg.ApplyBehaviors["auth.client_secret"]) +} + +func TestApplyKeyDefs_MapIndependence(t *testing.T) { + t.Parallel() + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{DSN: "postgres://localhost/db"}, + } + + defs := []domain.KeyDef{ + {Key: "app.log_level", ApplyBehavior: domain.ApplyLiveRead}, + {Key: "postgres.max_open_conns", ApplyBehavior: domain.ApplyLiveRead}, + } + + cfg.ApplyKeyDefs(defs) + + // Mutate the Postgres sub-config map… + cfg.Postgres.ApplyBehaviors["injected.key"] = domain.ApplyBundleRebuild + + // …and verify the top-level map is NOT affected. + _, exists := cfg.ApplyBehaviors["injected.key"] + assert.False(t, exists, "top-level ApplyBehaviors must be independent of Postgres sub-config map") + + // Also verify the reverse: mutating the top-level map does not bleed into Postgres. 
+ cfg.ApplyBehaviors["another.key"] = domain.ApplyBootstrapOnly + _, existsInPg := cfg.Postgres.ApplyBehaviors["another.key"] + assert.False(t, existsInPg, "Postgres ApplyBehaviors must be independent of top-level map") +} + +func TestSecretStoreConfig_String(t *testing.T) { + t.Parallel() + + t.Run("non-nil config redacts master key", func(t *testing.T) { + t.Parallel() + + s := &SecretStoreConfig{MasterKey: "super-secret-key-32-bytes-here!!", SecretKeys: []string{"auth.client_secret"}} + assert.Equal(t, "SecretStoreConfig{MasterKey:REDACTED}", s.String()) + }) + + t.Run("nil receiver returns <nil>", func(t *testing.T) { + t.Parallel() + + var s *SecretStoreConfig + assert.Equal(t, "<nil>", s.String()) + }) + + t.Run("GoString delegates to String", func(t *testing.T) { + t.Parallel() + + s := &SecretStoreConfig{MasterKey: "should-not-appear"} + assert.Equal(t, s.String(), s.GoString()) + }) + + t.Run("nil GoString delegates to String", func(t *testing.T) { + t.Parallel() + + var s *SecretStoreConfig + assert.Equal(t, s.String(), s.GoString()) + }) +} + +func TestValidate_WeakMasterKey(t *testing.T) { + t.Parallel() + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{ + DSN: "postgres://localhost/db", + }, + Secrets: &SecretStoreConfig{ + MasterKey: "tooshort10", // 10 bytes — well below the 32-byte minimum + SecretKeys: []string{"auth.client_secret"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "32") +} + +func TestValidate_ValidMasterKey_Raw32Bytes(t *testing.T) { + t.Parallel() + + // Exactly 32 ASCII bytes — the minimum accepted raw form. 
+ masterKey := "12345678901234567890123456789012" + require.Len(t, []byte(masterKey), 32, "test precondition: key must be exactly 32 bytes") + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{ + DSN: "postgres://localhost/db", + }, + Secrets: &SecretStoreConfig{ + MasterKey: masterKey, + SecretKeys: []string{"auth.client_secret"}, + }, + } + + err := cfg.Validate() + assert.NoError(t, err) +} + +func TestApplyKeyDefs_MongoDBSubConfig(t *testing.T) { + t.Parallel() + + cfg := &BootstrapConfig{ + Backend: domain.BackendMongoDB, + MongoDB: &MongoBootstrapConfig{URI: "mongodb://localhost"}, + } + + defs := []domain.KeyDef{ + {Key: "app.log_level", ApplyBehavior: domain.ApplyLiveRead}, + } + + cfg.ApplyKeyDefs(defs) + + require.NotNil(t, cfg.MongoDB.ApplyBehaviors) + assert.Equal(t, domain.ApplyLiveRead, cfg.MongoDB.ApplyBehaviors["app.log_level"]) +} diff --git a/commons/systemplane/bootstrap/backend.go b/commons/systemplane/bootstrap/backend.go index 8fc39786..9ac636f2 100644 --- a/commons/systemplane/bootstrap/backend.go +++ b/commons/systemplane/bootstrap/backend.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "io" + "maps" + "sync" "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" "github.com/LerianStudio/lib-commons/v4/commons/systemplane/ports" @@ -26,10 +28,67 @@ type BackendResources struct { // bootstrap and adapter packages. type BackendFactory func(ctx context.Context, cfg *BootstrapConfig) (*BackendResources, error) -// backendFactories maps each supported BackendKind to its constructor. -// Entries are registered via RegisterBackendFactory at init time by the -// adapter packages (or by the wiring code that imports them). 
-var backendFactories = map[domain.BackendKind]BackendFactory{} +type backendRegistryState struct { + mu sync.RWMutex + factories map[domain.BackendKind]BackendFactory + initErrors []error +} + +func newBackendRegistryState() *backendRegistryState { + return &backendRegistryState{factories: map[domain.BackendKind]BackendFactory{}} +} + +func (s *backendRegistryState) reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.factories = map[domain.BackendKind]BackendFactory{} + s.initErrors = nil +} + +func (s *backendRegistryState) recordInitError(err error) { + if domain.IsNilValue(err) { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.initErrors = append(s.initErrors, err) +} + +func (s *backendRegistryState) register(kind domain.BackendKind, factory BackendFactory) error { + if !kind.IsValid() { + return fmt.Errorf("%w %q", errInvalidBackendKind, kind) + } + + if factory == nil { + return errNilBackendFactory + } + + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.factories[kind]; exists { + return fmt.Errorf("%w %q", errBackendAlreadyRegistered, kind) + } + + s.factories[kind] = factory + + return nil +} + +func (s *backendRegistryState) snapshot() (map[domain.BackendKind]BackendFactory, []error) { + s.mu.RLock() + defer s.mu.RUnlock() + + factories := make(map[domain.BackendKind]BackendFactory, len(s.factories)) + maps.Copy(factories, s.factories) + + return factories, append([]error(nil), s.initErrors...) +} + +var backendRegistry = newBackendRegistryState() var ( errNilBackendConfig = errors.New("bootstrap backend: config is nil") @@ -48,20 +107,14 @@ var ( // errors. This function exists only for test isolation and must not be called // in production code. func ResetBackendFactories() { - backendFactories = map[domain.BackendKind]BackendFactory{} - initErrors = nil + backendRegistry.reset() } -// initErrors collects errors from backend factory registrations that occur -// during init(). 
These are checked lazily in NewBackendFromConfig so that -// registration failures surface as actionable errors instead of panics. -var initErrors []error - // RecordInitError appends an error to the package-level init error list. // It is intended to be called from init() functions in adapter packages when // RegisterBackendFactory fails. func RecordInitError(err error) { - initErrors = append(initErrors, err) + backendRegistry.recordInitError(err) } // RegisterBackendFactory registers a BackendFactory for the given backend kind. @@ -71,21 +124,7 @@ func RecordInitError(err error) { // Registration is single-write per backend kind. Duplicate or nil // registrations are rejected to preserve bootstrap integrity. func RegisterBackendFactory(kind domain.BackendKind, factory BackendFactory) error { - if !kind.IsValid() { - return fmt.Errorf("%w %q", errInvalidBackendKind, kind) - } - - if factory == nil { - return errNilBackendFactory - } - - if _, exists := backendFactories[kind]; exists { - return fmt.Errorf("%w %q", errBackendAlreadyRegistered, kind) - } - - backendFactories[kind] = factory - - return nil + return backendRegistry.register(kind, factory) } // NewBackendFromConfig creates the backend family based on the configured @@ -95,6 +134,7 @@ func RegisterBackendFactory(kind domain.BackendKind, factory BackendFactory) err // The caller is responsible for calling Closer.Close when the resources are no // longer needed. Callers typically defer res.Closer.Close(). 
func NewBackendFromConfig(ctx context.Context, cfg *BootstrapConfig) (*BackendResources, error) { + factories, initErrors := backendRegistry.snapshot() + if len(initErrors) > 0 { return nil, fmt.Errorf("bootstrap backend: init registration errors: %w", errors.Join(initErrors...)) } @@ -109,7 +149,7 @@ func NewBackendFromConfig(ctx context.Context, cfg *BootstrapConfig) (*BackendRe return nil, fmt.Errorf("bootstrap backend: %w", err) } - factory, ok := backendFactories[cfg.Backend] + factory, ok := factories[cfg.Backend] if !ok { return nil, fmt.Errorf("%w %q (no factory registered)", errUnsupportedBackend, cfg.Backend) } @@ -119,25 +159,33 @@ func NewBackendFromConfig(ctx context.Context, cfg *BootstrapConfig) (*BackendRe return nil, err } + if err := validateBackendResources(resources); err != nil { + return nil, fmt.Errorf("bootstrap backend %q: %w", cfg.Backend, err) + } + + return resources, nil +} + +func validateBackendResources(resources *BackendResources) error { if resources == nil { - return nil, errNilBackendResources + return errNilBackendResources } if domain.IsNilValue(resources.Store) { - return nil, errNilBackendStore + return errNilBackendStore } if domain.IsNilValue(resources.History) { - return nil, errNilBackendHistoryStore + return errNilBackendHistoryStore } if domain.IsNilValue(resources.ChangeFeed) { - return nil, errNilBackendChangeFeed + return errNilBackendChangeFeed } if domain.IsNilValue(resources.Closer) { - return nil, errNilBackendCloser + return errNilBackendCloser } - return resources, nil + return nil } diff --git a/commons/systemplane/bootstrap/backend_test.go b/commons/systemplane/bootstrap/backend_test.go index b4f1abbb..05b907a3 100644 --- a/commons/systemplane/bootstrap/backend_test.go +++ b/commons/systemplane/bootstrap/backend_test.go @@ -2,6 +2,12 @@ // Copyright 2025 Lerian Studio. +// Tests in this file mutate the global backendRegistry (its factories and +// initErrors fields) to inject stubs and simulate error paths.
They must NOT use +// t.Parallel() at the top level to avoid data races on shared global state. +// Individual sub-tests that only read from the registry may be safe to +// parallelize, but the parent test functions are intentionally serial. + package bootstrap import ( @@ -49,6 +55,39 @@ type noopCloser struct{} func (noopCloser) Close() error { return nil } +// withRegistrySnapshot saves the global backendRegistry state and restores it +// via t.Cleanup so each test starts with a known state. +func withRegistrySnapshot(t *testing.T) { + t.Helper() + + origFactories, origErrors := backendRegistry.snapshot() + + t.Cleanup(func() { + backendRegistry.mu.Lock() + defer backendRegistry.mu.Unlock() + + backendRegistry.factories = origFactories + backendRegistry.initErrors = origErrors + }) +} + +// registryDeleteLocked removes a factory from the registry under the mutex, +// keeping test mutations aligned with the production synchronization contract. +func registryDeleteLocked(kind domain.BackendKind) { + backendRegistry.mu.Lock() + defer backendRegistry.mu.Unlock() + + delete(backendRegistry.factories, kind) +} + +// registrySetLocked overwrites a factory in the registry under the mutex. +func registrySetLocked(kind domain.BackendKind, factory BackendFactory) { + backendRegistry.mu.Lock() + defer backendRegistry.mu.Unlock() + + backendRegistry.factories[kind] = factory +} + func TestNewBackendFromConfig_NilConfig(t *testing.T) { res, err := NewBackendFromConfig(context.Background(), nil) @@ -142,15 +181,8 @@ func TestNewBackendFromConfig_UnsupportedBackend(t *testing.T) { func TestNewBackendFromConfig_NoFactoryRegistered(t *testing.T) { // Temporarily remove the postgres factory to simulate an unregistered backend. - // Save the original and restore after the test. 
- original, existed := backendFactories[domain.BackendPostgres] - delete(backendFactories, domain.BackendPostgres) - - t.Cleanup(func() { - if existed { - backendFactories[domain.BackendPostgres] = original - } - }) + withRegistrySnapshot(t) + registryDeleteLocked(domain.BackendPostgres) cfg := &BootstrapConfig{ Backend: domain.BackendPostgres, @@ -168,22 +200,12 @@ func TestNewBackendFromConfig_NoFactoryRegistered(t *testing.T) { func TestNewBackendFromConfig_FactoryReturnsError(t *testing.T) { // Not parallel: mutates global backendFactories for postgres kind. - // Register a factory that always fails, for a custom test backend kind. - // We use postgres kind here with a stub factory. testKind := domain.BackendPostgres - original, existed := backendFactories[testKind] + withRegistrySnapshot(t) expectedErr := fmt.Errorf("simulated connection failure") - backendFactories[testKind] = func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { + registrySetLocked(testKind, func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { return nil, expectedErr - } - - t.Cleanup(func() { - if existed { - backendFactories[testKind] = original - } else { - delete(backendFactories, testKind) - } }) cfg := &BootstrapConfig{ @@ -203,19 +225,11 @@ func TestNewBackendFromConfig_FactoryReturnsError(t *testing.T) { func TestNewBackendFromConfig_FactoryReturnsResources(t *testing.T) { // Not parallel: mutates global backendFactories for postgres kind. 
testKind := domain.BackendPostgres - original, existed := backendFactories[testKind] + withRegistrySnapshot(t) expected := &BackendResources{Store: noopStore{}, History: noopHistoryStore{}, ChangeFeed: noopChangeFeed{}, Closer: noopCloser{}} - backendFactories[testKind] = func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { + registrySetLocked(testKind, func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { return expected, nil - } - - t.Cleanup(func() { - if existed { - backendFactories[testKind] = original - } else { - delete(backendFactories, testKind) - } }) cfg := &BootstrapConfig{ @@ -268,18 +282,10 @@ func TestNewBackendFromConfig_FactoryReturnsIncompleteResources(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { testKind := domain.BackendPostgres - original, existed := backendFactories[testKind] + withRegistrySnapshot(t) - backendFactories[testKind] = func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { + registrySetLocked(testKind, func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { return tc.resources, nil - } - - t.Cleanup(func() { - if existed { - backendFactories[testKind] = original - } else { - delete(backendFactories, testKind) - } }) cfg := &BootstrapConfig{ @@ -301,21 +307,13 @@ func TestNewBackendFromConfig_FactoryReturnsIncompleteResources(t *testing.T) { func TestNewBackendFromConfig_AppliesDefaults(t *testing.T) { // Not parallel: mutates global backendFactories for postgres kind. 
testKind := domain.BackendPostgres - original, existed := backendFactories[testKind] + withRegistrySnapshot(t) var capturedCfg *BootstrapConfig - backendFactories[testKind] = func(_ context.Context, cfg *BootstrapConfig) (*BackendResources, error) { + registrySetLocked(testKind, func(_ context.Context, cfg *BootstrapConfig) (*BackendResources, error) { capturedCfg = cfg return &BackendResources{Store: noopStore{}, History: noopHistoryStore{}, ChangeFeed: noopChangeFeed{}, Closer: noopCloser{}}, nil - } - - t.Cleanup(func() { - if existed { - backendFactories[testKind] = original - } else { - delete(backendFactories, testKind) - } }) cfg := &BootstrapConfig{ @@ -339,16 +337,8 @@ func TestNewBackendFromConfig_AppliesDefaults(t *testing.T) { func TestRegisterBackendFactory_RejectsOverwrite(t *testing.T) { testKind := domain.BackendPostgres - original, existed := backendFactories[testKind] - - t.Cleanup(func() { - if existed { - backendFactories[testKind] = original - } else { - delete(backendFactories, testKind) - } - }) - delete(backendFactories, testKind) + withRegistrySnapshot(t) + registryDeleteLocked(testKind) firstCalled := false secondCalled := false @@ -366,7 +356,8 @@ func TestRegisterBackendFactory_RejectsOverwrite(t *testing.T) { require.Error(t, err) assert.ErrorIs(t, err, errBackendAlreadyRegistered) - factory := backendFactories[testKind] + factories, _ := backendRegistry.snapshot() + factory := factories[testKind] require.NotNil(t, factory) _, _ = factory(context.Background(), nil) @@ -377,62 +368,55 @@ func TestRegisterBackendFactory_RejectsOverwrite(t *testing.T) { func TestRegisterBackendFactory_RejectsNilFactory(t *testing.T) { kind := domain.BackendMongoDB - original, existed := backendFactories[kind] - backendFactories[kind] = func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { + withRegistrySnapshot(t) + registrySetLocked(kind, func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { return 
&BackendResources{Store: noopStore{}, History: noopHistoryStore{}, ChangeFeed: noopChangeFeed{}, Closer: noopCloser{}}, nil - } - - t.Cleanup(func() { - if existed { - backendFactories[kind] = original - } else { - delete(backendFactories, kind) - } }) err := RegisterBackendFactory(kind, nil) require.Error(t, err) assert.ErrorIs(t, err, errNilBackendFactory) - _, ok := backendFactories[kind] + factories, _ := backendRegistry.snapshot() + _, ok := factories[kind] assert.True(t, ok) } +func TestRecordInitError_NilIsIgnored(t *testing.T) { + withRegistrySnapshot(t) + + RecordInitError(nil) + + _, initErrors := backendRegistry.snapshot() + assert.Empty(t, initErrors, "RecordInitError(nil) must not append to initErrors") +} + // --------------------------------------------------------------------------- // MEDIUM-12: Verify ResetBackendFactories clears state for test isolation // --------------------------------------------------------------------------- func TestResetBackendFactories_ClearsRegistrations(t *testing.T) { // Not parallel: mutates global backendFactories. - // Save the original state. - origFactories := make(map[domain.BackendKind]BackendFactory, len(backendFactories)) - for k, v := range backendFactories { - origFactories[k] = v - } - - origErrors := initErrors - - t.Cleanup(func() { - backendFactories = origFactories - initErrors = origErrors - }) + withRegistrySnapshot(t) // Register a factory and record an init error. 
- delete(backendFactories, domain.BackendPostgres) + registryDeleteLocked(domain.BackendPostgres) require.NoError(t, RegisterBackendFactory(domain.BackendPostgres, func(_ context.Context, _ *BootstrapConfig) (*BackendResources, error) { return &BackendResources{Store: noopStore{}, History: noopHistoryStore{}, ChangeFeed: noopChangeFeed{}, Closer: noopCloser{}}, nil })) RecordInitError(fmt.Errorf("simulated init error")) - assert.Len(t, backendFactories, 1) + factories, initErrors := backendRegistry.snapshot() + assert.Len(t, factories, 1) assert.NotEmpty(t, initErrors) // Reset and verify. ResetBackendFactories() - assert.Empty(t, backendFactories, "ResetBackendFactories should clear all factories") - assert.Nil(t, initErrors, "ResetBackendFactories should clear init errors") + factories, initErrors = backendRegistry.snapshot() + assert.Empty(t, factories, "ResetBackendFactories should clear all factories") + assert.Empty(t, initErrors, "ResetBackendFactories should clear init errors") } var _ io.Closer = noopCloser{} diff --git a/commons/systemplane/bootstrap/config.go b/commons/systemplane/bootstrap/config.go index 945937c0..e4c194c6 100644 --- a/commons/systemplane/bootstrap/config.go +++ b/commons/systemplane/bootstrap/config.go @@ -3,8 +3,11 @@ package bootstrap import ( + "encoding/base64" "errors" "fmt" + "maps" + "os" "strings" "time" @@ -22,6 +25,7 @@ var ( ErrInvalidWatchMode = errors.New("systemplane: mongodb watch mode must be change_stream or poll") ErrInvalidPollInterval = errors.New("systemplane: mongodb poll interval must be greater than zero when watch mode is poll") ErrInvalidMongoIdentifier = errors.New("systemplane: invalid mongodb identifier") + ErrUnhandledBackend = errors.New("systemplane: unhandled backend kind") ) // BootstrapConfig holds the initial configuration needed to connect to the @@ -40,6 +44,41 @@ type SecretStoreConfig struct { SecretKeys []string } +// String implements fmt.Stringer to prevent accidental master key exposure in 
logs or spans. +func (s *SecretStoreConfig) String() string { + if s == nil { + return "" + } + + return "SecretStoreConfig{MasterKey:REDACTED}" +} + +// GoString implements fmt.GoStringer to prevent accidental master key exposure in %#v formatting. +func (s *SecretStoreConfig) GoString() string { + return s.String() +} + +// validate checks that the master key meets minimum size requirements. +// It accepts either exactly 32 raw bytes or a valid base64-encoded 32-byte value. +func (s *SecretStoreConfig) validate() error { + if s == nil { + return nil + } + + const masterKeySizeBytes = 32 + + if len([]byte(s.MasterKey)) == masterKeySizeBytes { + return nil + } + + decoded, err := base64.StdEncoding.DecodeString(s.MasterKey) + if err == nil && len(decoded) == masterKeySizeBytes { + return nil + } + + return fmt.Errorf("secret master key must be exactly %d raw bytes or a valid base64-encoded %d-byte value", masterKeySizeBytes, masterKeySizeBytes) +} + // PostgresBootstrapConfig holds PostgreSQL-specific bootstrap settings. type PostgresBootstrapConfig struct { DSN string @@ -62,20 +101,81 @@ type MongoBootstrapConfig struct { ApplyBehaviors map[string]domain.ApplyBehavior } +// ApplyKeyDefs propagates ApplyBehavior from the given KeyDefs into the +// bootstrap configuration and auto-configures secret encryption when secret +// keys are detected and the SYSTEMPLANE_SECRET_MASTER_KEY env var is set. +// +// Reading from the environment here is intentional and consistent with the +// rest of the bootstrap layer (LoadFromEnv, LoadFromEnvOrDefault). Bootstrap +// is the one place where direct os.Getenv calls are expected: the entire +// purpose of this layer is to translate process-environment state into typed +// Go configuration before any backend is created. +// +// This is typically called once during service startup, after +// LoadFromEnvOrDefault and before creating the backend. 
+func (cfg *BootstrapConfig) ApplyKeyDefs(defs []domain.KeyDef) { + if cfg == nil { + return + } + + applyBehaviors := make(map[string]domain.ApplyBehavior, len(defs)) + secretKeys := make([]string, 0) + + for _, def := range defs { + applyBehaviors[def.Key] = def.ApplyBehavior + + if def.Secret { + secretKeys = append(secretKeys, def.Key) + } + } + + cfg.ApplyBehaviors = applyBehaviors + + if cfg.Postgres != nil { + pgBehaviors := make(map[string]domain.ApplyBehavior, len(applyBehaviors)) + maps.Copy(pgBehaviors, applyBehaviors) + + cfg.Postgres.ApplyBehaviors = pgBehaviors + } + + if cfg.MongoDB != nil { + mgBehaviors := make(map[string]domain.ApplyBehavior, len(applyBehaviors)) + maps.Copy(mgBehaviors, applyBehaviors) + + cfg.MongoDB.ApplyBehaviors = mgBehaviors + } + + masterKey := strings.TrimSpace(os.Getenv(EnvSecretMasterKey)) + if len(secretKeys) == 0 || masterKey == "" { + return + } + + cfg.Secrets = &SecretStoreConfig{ + MasterKey: masterKey, + SecretKeys: secretKeys, + } +} + // Validate checks that the bootstrap configuration is well-formed. func (cfg *BootstrapConfig) Validate() error { if cfg == nil || !cfg.Backend.IsValid() { return fmt.Errorf("%w: %q", ErrMissingBackend, backendString(cfg)) } + if cfg.Secrets != nil { + if err := cfg.Secrets.validate(); err != nil { + return err + } + } + switch cfg.Backend { case domain.BackendPostgres: return validatePostgresBootstrap(cfg.Postgres) case domain.BackendMongoDB: return validateMongoBootstrap(cfg.MongoDB) + default: + return fmt.Errorf("validate: %w: %q", ErrUnhandledBackend, cfg.Backend) } - - return nil } // ApplyDefaults fills in zero-value fields with sensible defaults. 
diff --git a/commons/systemplane/bootstrap/env.go b/commons/systemplane/bootstrap/env.go index 894c59ca..0f2ff6cf 100644 --- a/commons/systemplane/bootstrap/env.go +++ b/commons/systemplane/bootstrap/env.go @@ -31,6 +31,46 @@ const ( EnvMongoPollIntervalSec = "SYSTEMPLANE_MONGODB_POLL_INTERVAL_SEC" ) +// EnvSecretMasterKey is the environment variable holding the AES-256 master key +// for encrypting/decrypting secret configuration values at rest. +const EnvSecretMasterKey = "SYSTEMPLANE_SECRET_MASTER_KEY" + +// LoadFromEnvOrDefault reads SYSTEMPLANE_* environment variables when present, +// otherwise falls back to a minimal Postgres backend configuration using the +// provided DSN. This covers the common case where a product embeds systemplane +// into its existing Postgres database and does not set dedicated SYSTEMPLANE_* +// env vars. +// +// When SYSTEMPLANE_BACKEND is set, this delegates entirely to LoadFromEnv. +// When it is not set, a BootstrapConfig is constructed with +// Backend=BackendPostgres and the fallbackDSN as the Postgres DSN. +// ApplyDefaults and Validate are called in both paths. +func LoadFromEnvOrDefault(fallbackDSN string) (*BootstrapConfig, error) { + if strings.TrimSpace(os.Getenv(EnvBackend)) != "" { + return LoadFromEnv() + } + + trimmedDSN := strings.TrimSpace(fallbackDSN) + if trimmedDSN == "" { + return nil, ErrMissingPostgresDSN + } + + cfg := &BootstrapConfig{ + Backend: domain.BackendPostgres, + Postgres: &PostgresBootstrapConfig{ + DSN: trimmedDSN, + }, + } + + cfg.ApplyDefaults() + + if err := cfg.Validate(); err != nil { + return nil, err + } + + return cfg, nil +} + // LoadFromEnv reads SYSTEMPLANE_* environment variables and returns a validated // BootstrapConfig. 
func LoadFromEnv() (*BootstrapConfig, error) { diff --git a/commons/systemplane/bootstrap/env_or_default_test.go b/commons/systemplane/bootstrap/env_or_default_test.go new file mode 100644 index 00000000..74edd33c --- /dev/null +++ b/commons/systemplane/bootstrap/env_or_default_test.go @@ -0,0 +1,57 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. + +package bootstrap + +import ( + "testing" + + "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadFromEnvOrDefault_FallbackDSN(t *testing.T) { + t.Setenv(EnvBackend, "") + // No SYSTEMPLANE_BACKEND set — should use fallback DSN. + cfg, err := LoadFromEnvOrDefault("postgres://user:pass@localhost:5432/mydb") + require.NoError(t, err) + assert.Equal(t, domain.BackendPostgres, cfg.Backend) + assert.NotNil(t, cfg.Postgres) + assert.Equal(t, "postgres://user:pass@localhost:5432/mydb", cfg.Postgres.DSN) + // Defaults applied. + assert.Equal(t, DefaultPostgresSchema, cfg.Postgres.Schema) + assert.Equal(t, DefaultPostgresEntriesTable, cfg.Postgres.EntriesTable) +} + +func TestLoadFromEnvOrDefault_DelegatesWhenEnvSet(t *testing.T) { + t.Setenv(EnvBackend, "postgres") + t.Setenv(EnvPostgresDSN, "postgres://env@host:5432/envdb") + + cfg, err := LoadFromEnvOrDefault("postgres://fallback@host:5432/fallbackdb") + require.NoError(t, err) + // Should use the env DSN, not the fallback. 
+ assert.Equal(t, "postgres://env@host:5432/envdb", cfg.Postgres.DSN) +} + +func TestLoadFromEnvOrDefault_EmptyFallbackDSN(t *testing.T) { + t.Setenv(EnvBackend, "") + _, err := LoadFromEnvOrDefault("") + require.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPostgresDSN) +} + +func TestLoadFromEnvOrDefault_WhitespaceOnlyFallbackDSN(t *testing.T) { + t.Setenv(EnvBackend, "") + _, err := LoadFromEnvOrDefault(" ") + require.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPostgresDSN) +} + +func TestLoadFromEnvOrDefault_TrimsFallbackDSN(t *testing.T) { + t.Setenv(EnvBackend, "") + cfg, err := LoadFromEnvOrDefault(" postgres://user:pass@localhost:5432/db ") + require.NoError(t, err) + assert.Equal(t, "postgres://user:pass@localhost:5432/db", cfg.Postgres.DSN) +} diff --git a/commons/systemplane/catalog/doc.go b/commons/systemplane/catalog/doc.go new file mode 100644 index 00000000..2345a6e9 --- /dev/null +++ b/commons/systemplane/catalog/doc.go @@ -0,0 +1,7 @@ +// Copyright 2025 Lerian Studio. + +// Package catalog defines canonical names, tiers, and components for +// configuration keys shared across Lerian products. Products validate +// their local KeyDefs against this catalog to prevent naming and +// classification drift. +package catalog diff --git a/commons/systemplane/catalog/keys_postgres.go b/commons/systemplane/catalog/keys_postgres.go new file mode 100644 index 00000000..dd38a05b --- /dev/null +++ b/commons/systemplane/catalog/keys_postgres.go @@ -0,0 +1,39 @@ +// Copyright 2025 Lerian Studio. 
+ +package catalog + +import "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" + +var postgresKeys = []SharedKey{ + // Primary connection — BundleRebuild, component: "postgres" + {Key: "postgres.primary_host", EnvVar: "POSTGRES_HOST", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.primary", Description: "PostgreSQL primary host"}, + {Key: "postgres.primary_port", EnvVar: "POSTGRES_PORT", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.primary", Description: "PostgreSQL primary port"}, + {Key: "postgres.primary_user", EnvVar: "POSTGRES_USER", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.primary", Description: "PostgreSQL primary user"}, + {Key: "postgres.primary_password", EnvVar: "POSTGRES_PASSWORD", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.primary", Secret: true, RedactPolicy: domain.RedactFull, Description: "PostgreSQL primary password"}, + {Key: "postgres.primary_db", EnvVar: "POSTGRES_DB", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.primary", Description: "PostgreSQL primary database name"}, + {Key: "postgres.primary_ssl_mode", EnvVar: "POSTGRES_SSLMODE", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.primary", Description: "PostgreSQL primary SSL mode"}, + + // Replica connection — BundleRebuild, component: "postgres" + {Key: "postgres.replica_host", EnvVar: "POSTGRES_REPLICA_HOST", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, 
Component: "postgres", Group: "postgres.replica", Description: "PostgreSQL replica host"}, + {Key: "postgres.replica_port", EnvVar: "POSTGRES_REPLICA_PORT", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.replica", Description: "PostgreSQL replica port"}, + {Key: "postgres.replica_user", EnvVar: "POSTGRES_REPLICA_USER", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.replica", Description: "PostgreSQL replica user"}, + {Key: "postgres.replica_password", EnvVar: "POSTGRES_REPLICA_PASSWORD", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.replica", Secret: true, RedactPolicy: domain.RedactFull, Description: "PostgreSQL replica password"}, + {Key: "postgres.replica_db", EnvVar: "POSTGRES_REPLICA_DB", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.replica", Description: "PostgreSQL replica database name"}, + {Key: "postgres.replica_ssl_mode", EnvVar: "POSTGRES_REPLICA_SSLMODE", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.replica", Description: "PostgreSQL replica SSL mode"}, + + // Pool tuning — CANONICAL: LiveRead (NOT BundleRebuild) + // Go's database/sql supports SetMaxOpenConns() etc. at runtime without teardown. 
+ {Key: "postgres.max_open_conns", EnvVar: "POSTGRES_MAX_OPEN_CONNS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "postgres.pool", Description: "Maximum open database connections"}, + {Key: "postgres.max_idle_conns", EnvVar: "POSTGRES_MAX_IDLE_CONNS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "postgres.pool", Description: "Maximum idle database connections"}, + {Key: "postgres.conn_max_lifetime_mins", EnvVar: "POSTGRES_CONN_MAX_LIFETIME_MINS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "postgres.pool", Description: "Maximum connection lifetime in minutes"}, + {Key: "postgres.conn_max_idle_time_mins", EnvVar: "POSTGRES_CONN_MAX_IDLE_TIME_MINS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "postgres.pool", Description: "Maximum idle time in minutes"}, + + // Connection bootstrap/rebuild tuning + {Key: "postgres.connect_timeout_sec", EnvVar: "POSTGRES_CONNECT_TIMEOUT_SEC", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres", Group: "postgres.pool", Description: "Connection timeout in seconds"}, + {Key: "postgres.migrations_path", EnvVar: "MIGRATIONS_PATH", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "postgres", Description: "Database migrations directory path"}, +} + +// PostgresKeys returns canonical keys for PostgreSQL configuration. 
+func PostgresKeys() []SharedKey { + return cloneSharedKeys(postgresKeys) +} diff --git a/commons/systemplane/catalog/keys_redis.go b/commons/systemplane/catalog/keys_redis.go new file mode 100644 index 00000000..bdb232ed --- /dev/null +++ b/commons/systemplane/catalog/keys_redis.go @@ -0,0 +1,28 @@ +// Copyright 2025 Lerian Studio. + +package catalog + +import "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" + +var redisKeys = []SharedKey{ + // Connection — BundleRebuild, component: "redis" + {Key: "redis.host", EnvVar: "REDIS_HOST", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.connection", Description: "Redis host address"}, + {Key: "redis.master_name", EnvVar: "REDIS_MASTER_NAME", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.connection", Description: "Redis Sentinel master name"}, + {Key: "redis.password", EnvVar: "REDIS_PASSWORD", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.connection", Secret: true, RedactPolicy: domain.RedactFull, Description: "Redis password"}, + {Key: "redis.db", EnvVar: "REDIS_DB", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.connection", Description: "Redis database index"}, + {Key: "redis.protocol", EnvVar: "REDIS_PROTOCOL", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.connection", Description: "Redis protocol version"}, + {Key: "redis.tls", EnvVar: "REDIS_TLS", ValueType: domain.ValueTypeBool, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.connection", Description: "Enable Redis TLS"}, + {Key: "redis.ca_cert", EnvVar: 
"REDIS_CA_CERT", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.connection", Secret: true, RedactPolicy: domain.RedactFull, Description: "Redis CA certificate"}, + // Pool + {Key: "redis.pool_size", EnvVar: "REDIS_POOL_SIZE", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.pool", Description: "Redis connection pool size"}, + {Key: "redis.min_idle_conns", EnvVar: "REDIS_MIN_IDLE_CONNS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.pool", Description: "Redis minimum idle connections"}, + // Timeouts — BundleRebuild (affect client config) + {Key: "redis.read_timeout_ms", EnvVar: "REDIS_READ_TIMEOUT_MS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.timeouts", Description: "Redis read timeout in milliseconds"}, + {Key: "redis.write_timeout_ms", EnvVar: "REDIS_WRITE_TIMEOUT_MS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.timeouts", Description: "Redis write timeout in milliseconds"}, + {Key: "redis.dial_timeout_ms", EnvVar: "REDIS_DIAL_TIMEOUT_MS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Group: "redis.timeouts", Description: "Redis dial timeout in milliseconds"}, +} + +// RedisKeys returns canonical keys for Redis configuration. +func RedisKeys() []SharedKey { + return cloneSharedKeys(redisKeys) +} diff --git a/commons/systemplane/catalog/keys_shared.go b/commons/systemplane/catalog/keys_shared.go new file mode 100644 index 00000000..561a2b02 --- /dev/null +++ b/commons/systemplane/catalog/keys_shared.go @@ -0,0 +1,119 @@ +// Copyright 2025 Lerian Studio. 
+
+package catalog
+
+import "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
+
+// corsKeys: CORS middleware settings. Every entry is ApplyLiveRead and
+// mutable at runtime, so changes take effect without a restart.
+var corsKeys = []SharedKey{
+	{Key: "cors.allowed_origins", EnvVar: "CORS_ALLOWED_ORIGINS", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "cors", Description: "CORS allowed origins (comma-separated)"},
+	{Key: "cors.allowed_methods", EnvVar: "CORS_ALLOWED_METHODS", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "cors", Description: "CORS allowed methods (comma-separated)"},
+	{Key: "cors.allowed_headers", EnvVar: "CORS_ALLOWED_HEADERS", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "cors", Description: "CORS allowed headers (comma-separated)"},
+}
+
+// appServerKeys: application identity, logging, and HTTP server settings.
+// The server.* and server.tls.* entries are ApplyBootstrapOnly and
+// non-mutable — a restart is required for changes to take effect.
+var appServerKeys = []SharedKey{
+	{Key: "app.env_name", EnvVar: "ENV_NAME", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "app", Description: "Environment name"},
+	// app.log_level is declared with ComponentNone because the catalog
+	// represents the minimal canonical definition. Products that use
+	// ComponentDiff to trigger hot-reload of their logger (e.g., zap
+	// AtomicLevel) MAY override Component to "logger" in their KeyDefs.
+	// This is an expected, documented deviation — use
+	// WithKnownDeviation("app.log_level", "Component") in catalog tests.
+	{Key: "app.log_level", EnvVar: "LOG_LEVEL", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "app", Description: "Application log level"},
+	{Key: "server.address", EnvVar: "SERVER_ADDRESS", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "server", Description: "HTTP server listen address"},
+	{Key: "server.body_limit_bytes", EnvVar: "HTTP_BODY_LIMIT_BYTES", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "server", Description: "HTTP request body size limit in bytes"},
+	{Key: "server.tls_cert_file", EnvVar: "SERVER_TLS_CERT_FILE", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "server.tls", Description: "TLS certificate file path"},
+	{Key: "server.tls_key_file", EnvVar: "SERVER_TLS_KEY_FILE", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "server.tls", Description: "TLS key file path"},
+}
+
+// rateLimitKeys: rate limiting controls. All entries are ApplyLiveRead and
+// mutable, so limits can be tuned on a running service.
+var rateLimitKeys = []SharedKey{
+	{Key: "rate_limit.enabled", EnvVar: "RATE_LIMIT_ENABLED", ValueType: domain.ValueTypeBool, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rate_limit", Description: "Enable rate limiting"},
+	{Key: "rate_limit.max", EnvVar: "RATE_LIMIT_MAX", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rate_limit", Description: "Maximum requests per window"},
+	{Key: "rate_limit.expiry_sec", EnvVar: "RATE_LIMIT_EXPIRY_SEC", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rate_limit", Description: "Rate limit window duration in seconds"},
+}
+
+// AuthKeys defines canonical auth keys.
+// EnvVar is intentionally empty because products and plugins use different
+// environment variable conventions for the same canonical systemplane keys.
+// MatchEnvVars enumerates the env var names each product convention uses;
+// auth.client_secret is the only secret here and is fully redacted.
+var authKeys = []SharedKey{
+	{Key: "auth.enabled", EnvVar: "", MatchEnvVars: []string{"AUTH_ENABLED", "PLUGIN_AUTH_ENABLED"}, ValueType: domain.ValueTypeBool, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "auth", Description: "Enable authentication middleware"},
+	{Key: "auth.address", EnvVar: "", MatchEnvVars: []string{"AUTH_ADDRESS", "PLUGIN_AUTH_ADDRESS"}, ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "auth", Description: "Auth service address"},
+	{Key: "auth.client_id", EnvVar: "", MatchEnvVars: []string{"AUTH_CLIENT_ID", "PLUGIN_AUTH_CLIENT_ID"}, ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "auth", Description: "Auth client ID"},
+	{Key: "auth.client_secret", EnvVar: "", MatchEnvVars: []string{"AUTH_CLIENT_SECRET", "PLUGIN_AUTH_CLIENT_SECRET"}, ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "auth", Secret: true, RedactPolicy: domain.RedactFull, Description: "Auth client secret"},
+	{Key: "auth.cache_ttl_sec", EnvVar: "", MatchEnvVars: []string{"AUTH_CACHE_TTL_SEC"}, ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "auth", Description: "Auth cache TTL in seconds"},
+}
+
+// telemetryKeys: OpenTelemetry resource/exporter settings. All entries are
+// ApplyBootstrapOnly and non-mutable.
+var telemetryKeys = []SharedKey{
+	{Key: "telemetry.enabled", EnvVar: "ENABLE_TELEMETRY", ValueType: domain.ValueTypeBool, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "telemetry", Description: "Enable OpenTelemetry"},
+	{Key: "telemetry.service_name", EnvVar: "OTEL_RESOURCE_SERVICE_NAME", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "telemetry", Description: "OTEL service name"},
+	{Key: "telemetry.library_name", EnvVar: "OTEL_LIBRARY_NAME", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "telemetry", Description: "OTEL library name"},
+	{Key: "telemetry.service_version", EnvVar: "OTEL_RESOURCE_SERVICE_VERSION", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "telemetry", Description: "OTEL service version"},
+	{Key: "telemetry.deployment_env", EnvVar: "OTEL_RESOURCE_DEPLOYMENT_ENVIRONMENT", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "telemetry", Description: "OTEL deployment environment"},
+	{Key: "telemetry.collector_endpoint", EnvVar: "OTEL_EXPORTER_OTLP_ENDPOINT", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, MutableAtRuntime: false, Component: domain.ComponentNone, Group: "telemetry", Description: "OTEL collector endpoint"},
+}
+
+// rabbitMQKeys: RabbitMQ settings. Connection-level entries (url, exchange)
+// use ApplyBundleRebuild; messaging-level entries are ApplyLiveRead.
+// NOTE(review): rabbitmq.enabled uses ApplyBundleRebuildAndReconcile while
+// url/exchange use plain ApplyBundleRebuild — confirm this asymmetry is
+// intended by the domain package semantics.
+var rabbitMQKeys = []SharedKey{
+	{Key: "rabbitmq.enabled", EnvVar: "RABBITMQ_ENABLED", ValueType: domain.ValueTypeBool, ApplyBehavior: domain.ApplyBundleRebuildAndReconcile, MutableAtRuntime: true, Component: "rabbitmq", Group: "rabbitmq", Description: "Enable RabbitMQ integration"},
+	{Key: "rabbitmq.url", EnvVar: "RABBITMQ_URL", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "rabbitmq", Group: "rabbitmq.connection", Secret: true, RedactPolicy: domain.RedactFull, Description: "RabbitMQ connection URL"},
+	{Key: "rabbitmq.exchange", EnvVar: "RABBITMQ_EXCHANGE", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "rabbitmq", Group: "rabbitmq.connection", Description: "RabbitMQ exchange name"},
+	// LiveRead messaging params
+	{Key: "rabbitmq.routing_key_prefix", EnvVar: "RABBITMQ_ROUTING_KEY_PREFIX", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rabbitmq.messaging", Description: "RabbitMQ routing key prefix"},
+	{Key: "rabbitmq.publish_timeout_ms", EnvVar: "RABBITMQ_PUBLISH_TIMEOUT_MS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rabbitmq.messaging", Description: "RabbitMQ publish timeout in milliseconds"},
+	{Key: "rabbitmq.max_retries", EnvVar: "RABBITMQ_MAX_RETRIES", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rabbitmq.messaging", Description: "RabbitMQ maximum publish retries"},
+	{Key: "rabbitmq.retry_backoff_ms", EnvVar: "RABBITMQ_RETRY_BACKOFF_MS", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rabbitmq.messaging", Description: "RabbitMQ retry backoff in milliseconds"},
+	{Key: "rabbitmq.event_signing_secret", EnvVar: "RABBITMQ_EVENT_SIGNING_SECRET", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone, Group: "rabbitmq.messaging", Secret: true, RedactPolicy: domain.RedactFull, Description: "RabbitMQ event signing secret"},
+}
+
+// CORSKeys returns canonical keys for CORS configuration.
+func CORSKeys() []SharedKey {
+	return cloneSharedKeys(corsKeys)
+}
+
+// AppServerKeys returns canonical keys for application and HTTP server config.
+func AppServerKeys() []SharedKey {
+	return cloneSharedKeys(appServerKeys)
+}
+
+// RateLimitKeys returns canonical keys for rate limiting.
+func RateLimitKeys() []SharedKey {
+	return cloneSharedKeys(rateLimitKeys)
+}
+
+// AuthKeys returns canonical auth keys.
+func AuthKeys() []SharedKey {
+	return cloneSharedKeys(authKeys)
+}
+
+// TelemetryKeys returns canonical keys for OpenTelemetry configuration.
+func TelemetryKeys() []SharedKey {
+	return cloneSharedKeys(telemetryKeys)
+}
+
+// RabbitMQKeys returns canonical keys for RabbitMQ configuration.
+func RabbitMQKeys() []SharedKey {
+	return cloneSharedKeys(rabbitMQKeys)
+}
+
+// AllSharedKeys returns all canonical keys from all categories.
+// postgresKeys and redisKeys are declared in sibling files of this package
+// (keys_postgres.go / keys_redis.go — not shown here). The result is a deep
+// copy; mutating it does not affect the catalog.
+func AllSharedKeys() []SharedKey {
+	all := make([]SharedKey, 0,
+		len(postgresKeys)+
+			len(redisKeys)+
+			len(corsKeys)+
+			len(appServerKeys)+
+			len(rateLimitKeys)+
+			len(authKeys)+
+			len(telemetryKeys)+
+			len(rabbitMQKeys),
+	)
+
+	all = append(all, postgresKeys...)
+	all = append(all, redisKeys...)
+	all = append(all, corsKeys...)
+	all = append(all, appServerKeys...)
+	all = append(all, rateLimitKeys...)
+	all = append(all, authKeys...)
+	all = append(all, telemetryKeys...)
+	all = append(all, rabbitMQKeys...)
+
+	return cloneSharedKeys(all)
+}
diff --git a/commons/systemplane/catalog/shared_key.go b/commons/systemplane/catalog/shared_key.go
new file mode 100644
index 00000000..ed5243a0
--- /dev/null
+++ b/commons/systemplane/catalog/shared_key.go
@@ -0,0 +1,38 @@
+// Copyright 2025 Lerian Studio.
+
+package catalog
+
+import "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
+
+// SharedKey defines the canonical metadata for a cross-product config key.
+// Products that register a key matching this catalog entry MUST use the
+// same Key name, EnvVar, ApplyBehavior, Component, MutableAtRuntime,
+// Secret, and RedactPolicy values.
+type SharedKey struct {
+	Key              string
+	EnvVar           string // canonical env var (empty = varies by product)
+	MatchEnvVars     []string
+	ValueType        domain.ValueType
+	ApplyBehavior    domain.ApplyBehavior
+	MutableAtRuntime bool
+	Component        string
+	Group            string
+	Secret           bool
+	RedactPolicy     domain.RedactPolicy
+	Description      string
+}
+
+// cloneSharedKeys returns a deep copy of keys: the slice header is copied and
+// each element's MatchEnvVars slice is re-allocated, so callers can mutate the
+// result without affecting the package-level catalog tables. A nil input
+// yields nil.
+func cloneSharedKeys(keys []SharedKey) []SharedKey {
+	if keys == nil {
+		return nil
+	}
+
+	cloned := make([]SharedKey, len(keys))
+	copy(cloned, keys)
+
+	for i := range cloned {
+		cloned[i].MatchEnvVars = append([]string(nil), cloned[i].MatchEnvVars...)
+	}
+
+	return cloned
+}
diff --git a/commons/systemplane/catalog/validate.go b/commons/systemplane/catalog/validate.go
new file mode 100644
index 00000000..9caeb127
--- /dev/null
+++ b/commons/systemplane/catalog/validate.go
@@ -0,0 +1,250 @@
+// Copyright 2025 Lerian Studio.
+
+package catalog
+
+import (
+	"fmt"
+	"slices"
+	"sort"
+	"strconv"
+	"strings"
+
+	"github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
+)
+
+// Mismatch describes a divergence between a product's KeyDef and the catalog.
+type Mismatch struct {
+	CatalogKey   string // canonical key name
+	ProductKey   string // product's key name (same if names match)
+	Field        string // "Key", "EnvVar", "ApplyBehavior", "Component", "MutableAtRuntime", "ValueType", "Secret", "RedactPolicy"
+	CatalogValue string
+	ProductValue string
+}
+
+// String returns a human-readable description of the mismatch.
+// When the product key was matched via env var (names differ), both key
+// names are shown so the reader can locate the canonical entry.
+func (m Mismatch) String() string {
+	if m.CatalogKey != m.ProductKey {
+		return fmt.Sprintf("key %q (catalog: %q): %s: want %s, got %s",
+			m.ProductKey, m.CatalogKey, m.Field, m.CatalogValue, m.ProductValue)
+	}
+
+	return fmt.Sprintf("key %q: %s: want %s, got %s",
+		m.CatalogKey, m.Field, m.CatalogValue, m.ProductValue)
+}
+
+// ValidateOption configures the behavior of ValidateKeyDefs.
+type ValidateOption func(*validateConfig)
+
+// validateConfig accumulates the suppression rules built by ValidateOption
+// functions before a validation run.
+type validateConfig struct {
+	ignoreFields    map[string]bool
+	knownDeviations map[string]map[string]bool // key → field → true
+}
+
+// WithIgnoreFields tells ValidateKeyDefsWithOptions to skip comparison of the
+// given fields for all keys. Common usage: WithIgnoreFields("EnvVar") when the
+// product does not set EnvVar on its KeyDefs because values come from the
+// systemplane store rather than environment variables.
+func WithIgnoreFields(fields ...string) ValidateOption {
+	return func(vc *validateConfig) {
+		for _, field := range fields {
+			vc.ignoreFields[field] = true
+		}
+	}
+}
+
+// WithKnownDeviation tells ValidateKeyDefsWithOptions to skip a specific
+// (catalogKey, field) pair. Use this for intentional, documented deviations
+// such as overriding a key's Component for product-specific ComponentDiff
+// behavior.
+func WithKnownDeviation(catalogKey, field string) ValidateOption {
+	return func(vc *validateConfig) {
+		if vc.knownDeviations[catalogKey] == nil {
+			vc.knownDeviations[catalogKey] = make(map[string]bool)
+		}
+
+		vc.knownDeviations[catalogKey][field] = true
+	}
+}
+
+// newValidateConfig builds a validateConfig from opts. Nil options in the
+// slice are tolerated and skipped.
+func newValidateConfig(opts []ValidateOption) *validateConfig {
+	vc := &validateConfig{
+		ignoreFields:    make(map[string]bool),
+		knownDeviations: make(map[string]map[string]bool),
+	}
+
+	for _, opt := range opts {
+		if opt != nil {
+			opt(vc)
+		}
+	}
+
+	return vc
+}
+
+// shouldSkip reports whether a mismatch on (catalogKey, field) was suppressed
+// either globally via WithIgnoreFields or per-key via WithKnownDeviation.
+func (vc *validateConfig) shouldSkip(catalogKey, field string) bool {
+	if vc.ignoreFields[field] {
+		return true
+	}
+
+	if fields, ok := vc.knownDeviations[catalogKey]; ok && fields[field] {
+		return true
+	}
+
+	return false
+}
+
+// ValidateKeyDefs checks product KeyDefs against the canonical catalog.
+// Matching is by exact Key name first, then by EnvVar or MatchEnvVars when the
+// product key name differs but still points to the same canonical configuration
+// concept.
+// For each match, compares Key, EnvVar, ApplyBehavior, Component,
+// MutableAtRuntime, ValueType, Secret, and RedactPolicy.
+//
+// Does NOT flag keys that exist in the product but not in the catalog
+// (those are product-specific keys, which is fine).
+//
+// catalogKeys accepts variadic slices so callers can pass individual categories
+// or AllSharedKeys(). To suppress known deviations, use
+// ValidateKeyDefsWithOptions which accepts ValidateOption values.
+// ValidateKeyDefs retains its original signature for backward compatibility.
+func ValidateKeyDefs(productDefs []domain.KeyDef, catalogKeys ...[]SharedKey) []Mismatch {
+	return ValidateKeyDefsWithOptions(productDefs, catalogKeys, nil)
+}
+
+// ValidateKeyDefsWithOptions is the full-featured variant of ValidateKeyDefs
+// that accepts filtering options. Use WithIgnoreFields and WithKnownDeviation
+// to suppress expected mismatches in product catalog tests.
+// The result is sorted by (CatalogKey, Field) for deterministic test output.
+func ValidateKeyDefsWithOptions(productDefs []domain.KeyDef, catalogKeys [][]SharedKey, opts []ValidateOption) []Mismatch {
+	keyIndex, envIndex := buildCatalogIndexes(catalogKeys...)
+	vc := newValidateConfig(opts)
+
+	var mismatches []Mismatch
+
+	for _, pd := range productDefs {
+		sk, found, matchedByEnv := resolveSharedKey(pd, keyIndex, envIndex)
+
+		if !found {
+			continue // product-specific key — not in catalog, nothing to check
+		}
+
+		for _, mm := range compareKeyDef(pd, sk, matchedByEnv) {
+			if vc.shouldSkip(mm.CatalogKey, mm.Field) {
+				continue
+			}
+
+			mismatches = append(mismatches, mm)
+		}
+	}
+
+	sort.Slice(mismatches, func(i, j int) bool {
+		if mismatches[i].CatalogKey != mismatches[j].CatalogKey {
+			return mismatches[i].CatalogKey < mismatches[j].CatalogKey
+		}
+
+		return mismatches[i].Field < mismatches[j].Field
+	})
+
+	return mismatches
+}
+
+// buildCatalogIndexes builds lookup maps over the given catalog slices: one
+// by canonical key name and one by accepted env var name.
+// NOTE(review): duplicate keys or env vars across slices silently overwrite
+// earlier entries here. TestAllSharedKeys_NoDuplicates/NoDuplicateEnvVars
+// guard the built-in catalog, but caller-supplied slices are not checked.
+func buildCatalogIndexes(catalogKeys ...[]SharedKey) (map[string]SharedKey, map[string]SharedKey) {
+	keyIndex := make(map[string]SharedKey)
+	envIndex := make(map[string]SharedKey)
+
+	for _, slice := range catalogKeys {
+		for _, sk := range slice {
+			keyIndex[sk.Key] = sk
+			for _, envVar := range allowedEnvVars(sk) {
+				envIndex[envVar] = sk
+			}
+		}
+	}
+
+	return keyIndex, envIndex
+}
+
+// resolveSharedKey finds the catalog entry for pd: first by exact Key name,
+// then by pd.EnvVar. The third result reports whether the match was made via
+// env var, which triggers an additional Key-name comparison in compareKeyDef.
+func resolveSharedKey(pd domain.KeyDef, keyIndex map[string]SharedKey, envIndex map[string]SharedKey) (SharedKey, bool, bool) {
+	if sk, found := keyIndex[pd.Key]; found {
+		return sk, true, false
+	}
+
+	if pd.EnvVar == "" {
+		return SharedKey{}, false, false
+	}
+
+	sk, found := envIndex[pd.EnvVar]
+
+	return sk, found, found
+}
+
+// compareKeyDef returns the field-level mismatches between pd and sk.
+// mismatchForString yields a zero Mismatch (Field == "") when values agree;
+// those placeholders are filtered out at the end.
+func compareKeyDef(pd domain.KeyDef, sk SharedKey, matchedByEnv bool) []Mismatch {
+	comparisons := []Mismatch{
+		mismatchForString(pd, sk, "ApplyBehavior", string(sk.ApplyBehavior), string(pd.ApplyBehavior)),
+		mismatchForString(pd, sk, "Component", sk.Component, pd.Component),
+		mismatchForString(pd, sk, "MutableAtRuntime", strconv.FormatBool(sk.MutableAtRuntime), strconv.FormatBool(pd.MutableAtRuntime)),
+		mismatchForString(pd, sk, "ValueType", string(sk.ValueType), string(pd.ValueType)),
+		mismatchForString(pd, sk, "Secret", strconv.FormatBool(sk.Secret), strconv.FormatBool(pd.Secret)),
+		mismatchForString(pd, sk, "RedactPolicy", string(normalizeRedactPolicy(sk.RedactPolicy, sk.Secret)), string(normalizeRedactPolicy(pd.RedactPolicy, pd.Secret))),
+	}
+
+	if matchedByEnv && pd.Key != sk.Key {
+		comparisons = append(comparisons, mismatchForString(pd, sk, "Key", sk.Key, pd.Key))
+	}
+
+	if expectedEnvVars := allowedEnvVars(sk); len(expectedEnvVars) > 0 {
+		if !slices.Contains(expectedEnvVars, pd.EnvVar) {
+			comparisons = append(comparisons, mismatchForString(pd, sk, "EnvVar", strings.Join(expectedEnvVars, "|"), pd.EnvVar))
+		}
+	} else if pd.EnvVar != "" {
+		// Catalog defines no env vars for this key but product sets one — flag it.
+		comparisons = append(comparisons, mismatchForString(pd, sk, "EnvVar", "", pd.EnvVar))
+	}
+
+	mismatches := make([]Mismatch, 0, len(comparisons))
+	for _, comparison := range comparisons {
+		if comparison.Field == "" {
+			continue
+		}
+
+		mismatches = append(mismatches, comparison)
+	}
+
+	return mismatches
+}
+
+// mismatchForString builds a Mismatch for field when the two values differ,
+// or a zero Mismatch (Field == "") when they agree.
+func mismatchForString(pd domain.KeyDef, sk SharedKey, field, catalogValue, productValue string) Mismatch {
+	if catalogValue == productValue {
+		return Mismatch{}
+	}
+
+	return Mismatch{
+		CatalogKey:   sk.Key,
+		ProductKey:   pd.Key,
+		Field:        field,
+		CatalogValue: catalogValue,
+		ProductValue: productValue,
+	}
+}
+
+// allowedEnvVars lists every env var name accepted for sk: the canonical
+// EnvVar (when set) followed by all MatchEnvVars aliases.
+func allowedEnvVars(sk SharedKey) []string {
+	allowed := make([]string, 0, 1+len(sk.MatchEnvVars))
+	if sk.EnvVar != "" {
+		allowed = append(allowed, sk.EnvVar)
+	}
+
+	allowed = append(allowed, sk.MatchEnvVars...)
+
+	return allowed
+}
+
+// normalizeRedactPolicy maps equivalent redaction settings to one canonical
+// form so they compare equal: Secret always implies RedactFull, and an empty
+// policy is treated as RedactNone.
+func normalizeRedactPolicy(policy domain.RedactPolicy, secret bool) domain.RedactPolicy {
+	if secret {
+		return domain.RedactFull
+	}
+
+	if policy == "" {
+		return domain.RedactNone
+	}
+
+	return policy
+}
diff --git a/commons/systemplane/catalog/validate_options_test.go b/commons/systemplane/catalog/validate_options_test.go
new file mode 100644
index 00000000..561ea412
--- /dev/null
+++ b/commons/systemplane/catalog/validate_options_test.go
@@ -0,0 +1,148 @@
+//go:build unit
+
+// Copyright 2025 Lerian Studio.
+
+package catalog
+
+import (
+	"testing"
+
+	"github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestValidateKeyDefsWithOptions_WithIgnoreFields(t *testing.T) {
+	t.Parallel()
+
+	productDefs := []domain.KeyDef{
+		{
+			Key: "app.log_level", ValueType: domain.ValueTypeString,
+			ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true,
+			Component: domain.ComponentNone,
+			// EnvVar intentionally empty — would normally mismatch the catalog's "LOG_LEVEL".
+		},
+	}
+
+	// Without options: should produce an EnvVar mismatch.
+	raw := ValidateKeyDefs(productDefs, AppServerKeys())
+	envVarMismatches := filterByField(raw, "EnvVar")
+	require.NotEmpty(t, envVarMismatches, "expected EnvVar mismatch without options")
+
+	// With WithIgnoreFields("EnvVar"): should suppress the EnvVar mismatch.
+	filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()},
+		[]ValidateOption{WithIgnoreFields("EnvVar")})
+	envVarMismatches = filterByField(filtered, "EnvVar")
+	assert.Empty(t, envVarMismatches, "EnvVar mismatches should be suppressed by WithIgnoreFields")
+}
+
+func TestValidateKeyDefsWithOptions_WithKnownDeviation(t *testing.T) {
+	t.Parallel()
+
+	productDefs := []domain.KeyDef{
+		{
+			Key: "app.log_level", ValueType: domain.ValueTypeString,
+			ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true,
+			Component: "logger", // intentional deviation from catalog's ComponentNone
+			EnvVar:    "LOG_LEVEL",
+		},
+	}
+
+	// Without options: should produce a Component mismatch.
+	raw := ValidateKeyDefs(productDefs, AppServerKeys())
+	componentMismatches := filterByField(raw, "Component")
+	require.NotEmpty(t, componentMismatches, "expected Component mismatch without options")
+
+	// With WithKnownDeviation: should suppress only that key+field.
+	filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()},
+		[]ValidateOption{WithKnownDeviation("app.log_level", "Component")})
+	componentMismatches = filterByField(filtered, "Component")
+	assert.Empty(t, componentMismatches, "Component mismatch should be suppressed by WithKnownDeviation")
+}
+
+func TestValidateKeyDefsWithOptions_KnownDeviationDoesNotSuppressOtherKeys(t *testing.T) {
+	t.Parallel()
+
+	productDefs := []domain.KeyDef{
+		{
+			Key: "app.log_level", ValueType: domain.ValueTypeString,
+			ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true,
+			Component: "logger", EnvVar: "LOG_LEVEL",
+		},
+		{
+			Key: "cors.allowed_origins", ValueType: domain.ValueTypeString,
+			ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true,
+			Component: "wrong-component", EnvVar: "CORS_ALLOWED_ORIGINS",
+		},
+	}
+
+	filtered := ValidateKeyDefsWithOptions(productDefs,
+		[][]SharedKey{AppServerKeys(), CORSKeys()},
+		[]ValidateOption{WithKnownDeviation("app.log_level", "Component")})
+
+	// app.log_level Component should be suppressed.
+	for _, mm := range filtered {
+		if mm.CatalogKey == "app.log_level" && mm.Field == "Component" {
+			t.Error("app.log_level Component deviation should have been suppressed")
+		}
+	}
+
+	// cors.allowed_origins Component should NOT be suppressed.
+	corsComponentFound := false
+	for _, mm := range filtered {
+		if mm.CatalogKey == "cors.allowed_origins" && mm.Field == "Component" {
+			corsComponentFound = true
+		}
+	}
+
+	assert.True(t, corsComponentFound, "cors.allowed_origins Component mismatch should NOT be suppressed")
+}
+
+func TestValidateKeyDefsWithOptions_NilOptions(t *testing.T) {
+	t.Parallel()
+
+	productDefs := []domain.KeyDef{
+		{
+			Key: "app.log_level", ValueType: domain.ValueTypeString,
+			ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true,
+			Component: domain.ComponentNone, EnvVar: "LOG_LEVEL",
+		},
+	}
+
+	// Nil options should work the same as no options.
+	result := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()}, nil)
+	assert.Empty(t, result, "matching key with nil options should produce no mismatches")
+}
+
+func TestValidateKeyDefsWithOptions_CombinedOptions(t *testing.T) {
+	t.Parallel()
+
+	productDefs := []domain.KeyDef{
+		{
+			Key: "app.log_level", ValueType: domain.ValueTypeString,
+			ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true,
+			Component: "logger",
+			// EnvVar empty — both Component and EnvVar would mismatch.
+		},
+	}
+
+	filtered := ValidateKeyDefsWithOptions(productDefs, [][]SharedKey{AppServerKeys()},
+		[]ValidateOption{
+			WithIgnoreFields("EnvVar"),
+			WithKnownDeviation("app.log_level", "Component"),
+		})
+
+	assert.Empty(t, filtered, "all mismatches should be suppressed by combined options")
+}
+
+// filterByField returns only the mismatches whose Field equals field.
+func filterByField(mismatches []Mismatch, field string) []Mismatch {
+	var result []Mismatch
+
+	for _, mm := range mismatches {
+		if mm.Field == field {
+			result = append(result, mm)
+		}
+	}
+
+	return result
+}
diff --git a/commons/systemplane/catalog/validate_test.go b/commons/systemplane/catalog/validate_test.go
new file mode 100644
index 00000000..08046fd0
--- /dev/null
+++ b/commons/systemplane/catalog/validate_test.go
@@ -0,0 +1,362 @@
+//go:build unit
+
+// Copyright 2025 Lerian Studio.
+
+package catalog
+
+import (
+	"testing"
+
+	"github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// helper builds a minimal valid KeyDef from a SharedKey so tests only need to
+// override the field under test.
+func keyDefFromShared(sk SharedKey) domain.KeyDef {
+	return domain.KeyDef{
+		Key:              sk.Key,
+		EnvVar:           sk.EnvVar,
+		Kind:             domain.KindConfig,
+		AllowedScopes:    []domain.Scope{domain.ScopeGlobal},
+		ValueType:        sk.ValueType,
+		Secret:           sk.Secret,
+		RedactPolicy:     sk.RedactPolicy,
+		ApplyBehavior:    sk.ApplyBehavior,
+		MutableAtRuntime: sk.MutableAtRuntime,
+		Component:        sk.Component,
+		Group:            sk.Group,
+		Description:      sk.Description,
+	}
+}
+
+func TestValidateKeyDefs_AllMatch(t *testing.T) {
+	// Build product defs that exactly mirror a subset of the catalog.
+	catalogSlice := []SharedKey{
+		{Key: "postgres.primary_host", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres"},
+		{Key: "redis.host", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis"},
+	}
+
+	productDefs := make([]domain.KeyDef, len(catalogSlice))
+	for i, sk := range catalogSlice {
+		productDefs[i] = keyDefFromShared(sk)
+	}
+
+	mismatches := ValidateKeyDefs(productDefs, catalogSlice)
+	assert.Empty(t, mismatches, "expected zero mismatches when product matches catalog exactly")
+}
+
+func TestValidateKeyDefs_ApplyBehaviorMismatch(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "postgres.max_open_conns", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.ApplyBehavior = domain.ApplyBundleRebuild // drift!
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "postgres.max_open_conns", mismatches[0].CatalogKey)
+	assert.Equal(t, "ApplyBehavior", mismatches[0].Field)
+	assert.Equal(t, string(domain.ApplyLiveRead), mismatches[0].CatalogValue)
+	assert.Equal(t, string(domain.ApplyBundleRebuild), mismatches[0].ProductValue)
+}
+
+func TestValidateKeyDefs_ComponentMismatch(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "redis.host", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis"},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.Component = "cache" // wrong component name
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "Component", mismatches[0].Field)
+	assert.Equal(t, "redis", mismatches[0].CatalogValue)
+	assert.Equal(t, "cache", mismatches[0].ProductValue)
+}
+
+func TestValidateKeyDefs_MutableAtRuntimeMismatch(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "app.log_level", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.MutableAtRuntime = false // drift!
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "MutableAtRuntime", mismatches[0].Field)
+	assert.Equal(t, "true", mismatches[0].CatalogValue)
+	assert.Equal(t, "false", mismatches[0].ProductValue)
+}
+
+func TestValidateKeyDefs_ValueTypeMismatch(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "postgres.primary_port", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres"},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.ValueType = domain.ValueTypeString // oops, string instead of int
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "ValueType", mismatches[0].Field)
+	assert.Equal(t, string(domain.ValueTypeInt), mismatches[0].CatalogValue)
+	assert.Equal(t, string(domain.ValueTypeString), mismatches[0].ProductValue)
+}
+
+func TestValidateKeyDefs_KeyMismatchMatchedByEnvVar(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "postgres.primary_ssl_mode", EnvVar: "POSTGRES_SSLMODE", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "postgres"},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.Key = "postgres.primary_sslmode"
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "Key", mismatches[0].Field)
+	assert.Equal(t, "postgres.primary_ssl_mode", mismatches[0].CatalogValue)
+	assert.Equal(t, "postgres.primary_sslmode", mismatches[0].ProductValue)
+}
+
+func TestValidateKeyDefs_KeyMismatchMatchedByAlternateEnvVar(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "auth.client_secret", MatchEnvVars: []string{"AUTH_CLIENT_SECRET", "PLUGIN_AUTH_CLIENT_SECRET"}, ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, Secret: true, RedactPolicy: domain.RedactFull},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.Key = "plugin.auth_client_secret"
+	pd.EnvVar = "PLUGIN_AUTH_CLIENT_SECRET"
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "Key", mismatches[0].Field)
+	assert.Equal(t, "auth.client_secret", mismatches[0].CatalogValue)
+	assert.Equal(t, "plugin.auth_client_secret", mismatches[0].ProductValue)
+}
+
+func TestValidateKeyDefs_EnvVarMismatch(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "redis.host", EnvVar: "REDIS_HOST", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis"},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.EnvVar = "CACHE_HOST"
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "EnvVar", mismatches[0].Field)
+	assert.Equal(t, "REDIS_HOST", mismatches[0].CatalogValue)
+	assert.Equal(t, "CACHE_HOST", mismatches[0].ProductValue)
+}
+
+func TestValidateKeyDefs_SecretMismatch(t *testing.T) {
+	catalogSlice := []SharedKey{{Key: "auth.client_secret", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, Secret: true, RedactPolicy: domain.RedactFull}}
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.Secret = false
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "Secret", mismatches[0].Field)
+}
+
+func TestValidateKeyDefs_RedactPolicyMismatch(t *testing.T) {
+	catalogSlice := []SharedKey{{Key: "ui.theme", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyLiveRead, Secret: false, RedactPolicy: domain.RedactNone}}
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.RedactPolicy = domain.RedactMask
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 1)
+	assert.Equal(t, "RedactPolicy", mismatches[0].Field)
+}
+
+func TestValidateKeyDefs_RedactPolicy_EmptyAndNoneAreEquivalent(t *testing.T) {
+	catalogSlice := []SharedKey{{Key: "redis.host", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis", Secret: false, RedactPolicy: domain.RedactNone}}
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.RedactPolicy = ""
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	assert.Empty(t, mismatches)
+}
+
+func TestValidateKeyDefs_RedactPolicy_SecretImpliesFull(t *testing.T) {
+	catalogSlice := []SharedKey{{Key: "auth.client_secret", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBootstrapOnly, Secret: true, RedactPolicy: domain.RedactFull}}
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.RedactPolicy = ""
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	assert.Empty(t, mismatches)
+}
+
+func TestValidateKeyDefs_ProductOnlyKeys_NotFlagged(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "redis.host", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis"},
+	}
+
+	productDefs := []domain.KeyDef{
+		keyDefFromShared(catalogSlice[0]),
+		{
+			Key:              "my_product.custom_feature",
+			Kind:             domain.KindConfig,
+			AllowedScopes:    []domain.Scope{domain.ScopeGlobal},
+			ValueType:        domain.ValueTypeBool,
+			ApplyBehavior:    domain.ApplyLiveRead,
+			MutableAtRuntime: true,
+			Component:        domain.ComponentNone,
+		},
+	}
+
+	mismatches := ValidateKeyDefs(productDefs, catalogSlice)
+	assert.Empty(t, mismatches, "product-only keys must not produce mismatches")
+}
+
+func TestValidateKeyDefs_CatalogOnlyKeys_NotFlagged(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "redis.host", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis"},
+		{Key: "redis.password", ValueType: domain.ValueTypeString, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis"},
+		{Key: "redis.db", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyBundleRebuild, MutableAtRuntime: true, Component: "redis"},
+	}
+
+	// Product only uses redis.host — the other two catalog keys are unused.
+	productDefs := []domain.KeyDef{
+		keyDefFromShared(catalogSlice[0]),
+	}
+
+	mismatches := ValidateKeyDefs(productDefs, catalogSlice)
+	assert.Empty(t, mismatches, "catalog-only keys must not produce mismatches")
+}
+
+func TestValidateKeyDefs_MultipleMismatches(t *testing.T) {
+	catalogSlice := []SharedKey{
+		{Key: "postgres.max_open_conns", ValueType: domain.ValueTypeInt, ApplyBehavior: domain.ApplyLiveRead, MutableAtRuntime: true, Component: domain.ComponentNone},
+	}
+
+	pd := keyDefFromShared(catalogSlice[0])
+	pd.ApplyBehavior = domain.ApplyBundleRebuild // wrong
+	pd.MutableAtRuntime = false                  // wrong
+	pd.ValueType = domain.ValueTypeString        // wrong
+	pd.Component = "postgres"                    // wrong (catalog says ComponentNone)
+
+	mismatches := ValidateKeyDefs([]domain.KeyDef{pd}, catalogSlice)
+	require.Len(t, mismatches, 4, "expected four mismatches for four divergent fields")
+
+	// Verify all fields are represented (sorted by field name within same key).
+	fields := make(map[string]bool)
+	for _, m := range mismatches {
+		fields[m.Field] = true
+		assert.Equal(t, "postgres.max_open_conns", m.CatalogKey)
+	}
+
+	assert.True(t, fields["ApplyBehavior"])
+	assert.True(t, fields["Component"])
+	assert.True(t, fields["MutableAtRuntime"])
+	assert.True(t, fields["ValueType"])
+}
+
+func TestValidateKeyDefs_EmptyInputs(t *testing.T) {
+	tests := []struct {
+		name        string
+		productDefs []domain.KeyDef
+		catalog     [][]SharedKey
+	}{
+		{name: "nil product defs", productDefs: nil, catalog: [][]SharedKey{PostgresKeys()}},
+		{name: "empty product defs", productDefs: []domain.KeyDef{}, catalog: [][]SharedKey{RedisKeys()}},
+		{name: "nil catalog", productDefs: []domain.KeyDef{{Key: "x"}}, catalog: nil},
+		{name: "empty catalog slice", productDefs: []domain.KeyDef{{Key: "x"}}, catalog: [][]SharedKey{{}}},
+		{name: "both empty", productDefs: nil, catalog: nil},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mismatches := ValidateKeyDefs(tt.productDefs, tt.catalog...)
+			assert.Empty(t, mismatches)
+		})
+	}
+}
+
+func TestAllSharedKeys_NoDuplicates(t *testing.T) {
+	all := AllSharedKeys()
+	seen := make(map[string]int, len(all))
+
+	for i, sk := range all {
+		if prev, exists := seen[sk.Key]; exists {
+			t.Errorf("duplicate key %q: first at index %d, again at index %d", sk.Key, prev, i)
+		}
+
+		seen[sk.Key] = i
+	}
+}
+
+func TestAllSharedKeys_AllValid(t *testing.T) {
+	// NOTE(review): many catalog entries leave RedactPolicy empty ("") — this
+	// test assumes domain.RedactPolicy("").IsValid() returns true. Confirm
+	// against the domain package, since normalizeRedactPolicy treats "" as a
+	// placeholder for RedactNone.
+	for _, sk := range AllSharedKeys() {
+		t.Run(sk.Key, func(t *testing.T) {
+			assert.NotEmpty(t, sk.Key, "SharedKey must have a non-empty Key")
+			assert.True(t, sk.ValueType.IsValid(), "ValueType %q is invalid", sk.ValueType)
+			assert.True(t, sk.ApplyBehavior.IsValid(), "ApplyBehavior %q is invalid", sk.ApplyBehavior)
+			assert.True(t, sk.RedactPolicy.IsValid(), "RedactPolicy %q is invalid", sk.RedactPolicy)
+			if sk.Secret {
+				assert.NotEqual(t, domain.RedactMask, sk.RedactPolicy, "secret key %q must not use mask redaction", sk.Key)
+			}
+			assert.NotEmpty(t, sk.Description, "SharedKey %q must have a Description", sk.Key)
+		})
+	}
+}
+
+func TestAllSharedKeys_NoDuplicateEnvVars(t *testing.T) {
+	seen := make(map[string]string)
+
+	for _, sk := range AllSharedKeys() {
+		for _, envVar := range allowedEnvVars(sk) {
+			if envVar == "" {
+				continue
+			}
+
+			if existingKey, ok := seen[envVar]; ok {
+				t.Fatalf("duplicate env var %q used by %q and %q", envVar, existingKey, sk.Key)
+			}
+
+			seen[envVar] = sk.Key
+		}
+	}
+}
+
+func TestMismatch_String(t *testing.T) {
+	tests := []struct {
+		name     string
+		mismatch Mismatch
+		want     string
+	}{
+		{
+			name: "same key name",
+			mismatch: Mismatch{
+				CatalogKey:   "redis.host",
+				ProductKey:   "redis.host",
+				Field:        "ApplyBehavior",
+				CatalogValue: "live-read",
+				ProductValue: "bundle-rebuild",
+			},
+			want: `key "redis.host": ApplyBehavior: want live-read, got bundle-rebuild`,
+		},
+		{
+			name: "different key names",
+			mismatch: Mismatch{
+				CatalogKey:   "postgres.primary_host",
+				ProductKey:   "pg.host",
+				Field:        "Component",
+				CatalogValue:
"postgres", + ProductValue: "pg", + }, + want: `key "pg.host" (catalog: "postgres.primary_host"): Component: want postgres, got pg`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.mismatch.String()) + }) + } +} diff --git a/commons/systemplane/domain/coercion_helpers.go b/commons/systemplane/domain/coercion_helpers.go new file mode 100644 index 00000000..6cc69b80 --- /dev/null +++ b/commons/systemplane/domain/coercion_helpers.go @@ -0,0 +1,366 @@ +// Copyright 2025 Lerian Studio. + +package domain + +import ( + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "time" +) + +type boolStringResult struct { + value bool + ok bool +} + +func coerceString(raw any, fallback string) string { + value, ok := tryCoerceString(raw) + if !ok { + return fallback + } + + return value +} + +func tryCoerceString(raw any) (string, bool) { + if IsNilValue(raw) { + return "", false + } + + switch v := raw.(type) { + case string: + return v, true + case []byte: + return string(v), true + case fmt.Stringer: + return safeStringerString(v) + default: + return safeSprint(raw) + } +} + +func safeStringerString(value fmt.Stringer) (result string, ok bool) { + if IsNilValue(value) { + return "", false + } + + defer func() { + if recover() != nil { + result = "" + ok = false + } + }() + + return value.String(), true +} + +func safeSprint(raw any) (result string, ok bool) { + defer func() { + if recover() != nil { + result = "" + ok = false + } + }() + + return fmt.Sprint(raw), true +} + +func coerceInt(raw any, fallback int) int { + value, ok := tryCoerceInt(raw) + if !ok { + return fallback + } + + return value +} + +func tryCoerceInt(raw any) (int, bool) { + if raw == nil { + return 0, false + } + + switch v := raw.(type) { + case int: + return v, true + case int64: + return intFromInt64(v) + case float64: + return intFromFloat64(v) + case string: + return intFromString(v) + case json.Number: + return intFromJSONNumber(v) + 
default: + return 0, false + } +} + +func intFromInt64(value int64) (int, bool) { + if value > math.MaxInt || value < math.MinInt { + return 0, false + } + + return int(value), true +} + +// intFromFloat64 converts a float64 to int with overflow/NaN protection. +// Uses float64(math.MaxInt)+1 for platform-independent bounds checking. +func intFromFloat64(value float64) (int, bool) { + if math.IsNaN(value) || math.IsInf(value, 0) { + return 0, false + } + + truncated := math.Trunc(value) + + // float64(math.MaxInt) rounds up on 64-bit platforms, so use >= with + // the exact float representation of MaxInt+1 for platform-independent bounds. + const maxIntPlusOne = float64(math.MaxInt) + 1 + if truncated >= maxIntPlusOne || truncated < math.MinInt { + return 0, false + } + + return int(truncated), true +} + +func intFromString(value string) (int, bool) { + parsed, err := strconv.Atoi(value) + if err != nil { + return 0, false + } + + return parsed, true +} + +func intFromJSONNumber(value json.Number) (int, bool) { + parsed, err := value.Int64() + if err != nil { + return 0, false + } + + return intFromInt64(parsed) +} + +func coerceBool(raw any, fallback bool) bool { + value, ok := tryCoerceBool(raw) + if !ok { + return fallback + } + + return value +} + +func tryCoerceBool(raw any) (bool, bool) { + if raw == nil { + return false, false + } + + switch v := raw.(type) { + case bool: + return v, true + case string: + parsed := parseCanonicalBoolString(v) + return parsed.value, parsed.ok + case int: + switch v { + case 0: + return false, true + case 1: + return true, true + default: + return false, false + } + default: + return false, false + } +} + +func parseCanonicalBoolString(value string) boolStringResult { + switch strings.ToLower(strings.TrimSpace(value)) { + case "true", "1": + return boolStringResult{value: true, ok: true} + case "false", "0": + return boolStringResult{value: false, ok: true} + default: + return boolStringResult{} + } +} + +func 
coerceFloat64(raw any, fallback float64) float64 { + value, ok := tryCoerceFloat64(raw) + if !ok { + return fallback + } + + return value +} + +func tryCoerceFloat64(raw any) (float64, bool) { + if raw == nil { + return 0, false + } + + switch v := raw.(type) { + case float64: + return finiteFloat64(v) + case int: + return float64(v), true + case int64: + return float64(v), true + case string: + f, err := strconv.ParseFloat(v, 64) + if err != nil { + return 0, false + } + + return finiteFloat64(f) + case json.Number: + f, err := v.Float64() + if err != nil { + return 0, false + } + + return finiteFloat64(f) + default: + return 0, false + } +} + +func finiteFloat64(value float64) (float64, bool) { + if math.IsNaN(value) || math.IsInf(value, 0) { + return 0, false + } + + return value, true +} + +func coerceDuration(raw any, fallback time.Duration, unit time.Duration) time.Duration { + value, ok := tryCoerceDuration(raw, unit) + if !ok { + return fallback + } + + return value +} + +func tryCoerceDuration(raw any, unit time.Duration) (time.Duration, bool) { + if raw == nil { + return 0, false + } + + switch v := raw.(type) { + case int: + return scaleDurationInt64(int64(v), unit) + case int64: + return scaleDurationInt64(v, unit) + case float64: + return scaleDurationFloat64(v, unit) + case string: + d, err := time.ParseDuration(v) + if err != nil { + return 0, false + } + + return d, true + default: + return 0, false + } +} + +func coerceStringSlice(raw any, fallback []string) []string { + value, ok := tryCoerceStringSlice(raw) + if !ok { + return cloneStringSlice(fallback) + } + + return value +} + +func tryCoerceStringSlice(raw any) ([]string, bool) { + if IsNilValue(raw) { + return nil, false + } + + switch v := raw.(type) { + case []string: + result := make([]string, 0, len(v)) + return append(result, v...), true + case []any: + result := make([]string, 0, len(v)) + for _, elem := range v { + stringValue, ok := elem.(string) + if !ok { + return nil, false + } + + 
trimmed := strings.TrimSpace(stringValue) + if trimmed == "" { + continue + } + + result = append(result, trimmed) + } + + return result, true + case string: + parts := strings.Split(v, ",") + + result := make([]string, 0, len(parts)) + for _, p := range parts { + trimmed := strings.TrimSpace(p) + if trimmed == "" { + continue + } + + result = append(result, trimmed) + } + + return result, true + default: + return nil, false + } +} + +func scaleDurationInt64(value int64, unit time.Duration) (time.Duration, bool) { + if value == 0 || unit == 0 { + return 0, true + } + + scaled := time.Duration(value) * unit + if scaled/unit != time.Duration(value) { + return 0, false + } + + return scaled, true +} + +func scaleDurationFloat64(value float64, unit time.Duration) (time.Duration, bool) { + if math.IsNaN(value) || math.IsInf(value, 0) { + return 0, false + } + + if value == 0 || unit == 0 { + return 0, true + } + + scaled := value * float64(unit) + + // Same boundary alias as intFromFloat64: float64(MaxInt64) rounds up + // on 64-bit. Use float64(math.MaxInt64)+1 for platform-independent bounds. + const maxInt64PlusOne = float64(math.MaxInt64) + 1 + if math.IsNaN(scaled) || math.IsInf(scaled, 0) || scaled >= maxInt64PlusOne || scaled < math.MinInt64 { + return 0, false + } + + return time.Duration(scaled), true +} + +func cloneStringSlice(values []string) []string { + if values == nil { + return nil + } + + return append([]string(nil), values...) +} diff --git a/commons/systemplane/domain/config_helpers.go b/commons/systemplane/domain/config_helpers.go new file mode 100644 index 00000000..4f08e8ef --- /dev/null +++ b/commons/systemplane/domain/config_helpers.go @@ -0,0 +1,43 @@ +// Copyright 2025 Lerian Studio. + +package domain + +import "time" + +// SnapString returns the config value for key coerced to a string, +// or fallback when snap is nil, the key is absent, or coercion fails. 
// SnapString returns the config value for key coerced to a string,
// or fallback when snap is nil, the key is absent, or coercion fails.
func SnapString(snap *Snapshot, key string, fallback string) string {
	return coerceString(snap.ConfigValue(key, nil), fallback)
}

// SnapInt returns the config value for key coerced to an int,
// or fallback when snap is nil, the key is absent, or coercion fails.
func SnapInt(snap *Snapshot, key string, fallback int) int {
	return coerceInt(snap.ConfigValue(key, nil), fallback)
}

// SnapBool returns the config value for key coerced to a bool,
// or fallback when snap is nil, the key is absent, or coercion fails.
func SnapBool(snap *Snapshot, key string, fallback bool) bool {
	return coerceBool(snap.ConfigValue(key, nil), fallback)
}

// SnapFloat64 returns the config value for key coerced to a float64,
// or fallback when snap is nil, the key is absent, or coercion fails.
func SnapFloat64(snap *Snapshot, key string, fallback float64) float64 {
	return coerceFloat64(snap.ConfigValue(key, nil), fallback)
}

// SnapDuration returns the config value for key coerced to a time.Duration.
// Numeric values are multiplied by unit (e.g. time.Second). String values must
// be parseable by time.ParseDuration. Returns fallback when snap is nil, the
// key is absent, or coercion fails.
func SnapDuration(snap *Snapshot, key string, fallback time.Duration, unit time.Duration) time.Duration {
	return coerceDuration(snap.ConfigValue(key, nil), fallback, unit)
}

// SnapStringSlice returns the config value for key coerced to a []string,
// or fallback when snap is nil, the key is absent, or coercion fails.
func SnapStringSlice(snap *Snapshot, key string, fallback []string) []string {
	return coerceStringSlice(snap.ConfigValue(key, nil), fallback)
}

// ---- file: commons/systemplane/domain/key_def.go (diff hunks; gaps elided) ----

package domain

import (
	"errors"
	"fmt"
	"regexp"
)

// ValueType classifies the data type of a configuration value.

// ErrInvalidValueType indicates an invalid value type.
var ErrInvalidValueType = errors.New("invalid value type")

// ErrInvalidRedactPolicy indicates an invalid redact policy.
var ErrInvalidRedactPolicy = errors.New("invalid redact policy")

// ErrInvalidEnvVar indicates an invalid environment variable name.
var ErrInvalidEnvVar = errors.New("invalid env var")

// envVarPattern matches conventional env var names: an uppercase ASCII letter
// followed by uppercase letters, digits, or underscores.
var envVarPattern = regexp.MustCompile(`^[A-Z][A-Z0-9_]*$`)

// IsValid reports whether the value type is supported.
func (vt ValueType) IsValid() bool {
	switch vt {

	RedactMask RedactPolicy = "mask"
)

// IsValid reports whether the redact policy is supported.
// The zero value is treated as valid for backward compatibility and is
// interpreted the same as RedactNone by the service layer.
func (rp RedactPolicy) IsValid() bool {
	switch rp {
	case "", RedactNone, RedactFull, RedactMask:
		return true
	default:
		return false
	}
}

// ValidatorFunc is a custom validation function for a key's value. It returns
// a non-nil error when the value is invalid.
type ValidatorFunc func(value any) error

const ComponentNone = "_none"

// the key's type, visibility, constraints, and runtime behavior.
type KeyDef struct {
	Key string
	// EnvVar is the optional environment variable backing this key; when set
	// it must match envVarPattern (validated in Validate).
	EnvVar        string
	Kind          Kind
	AllowedScopes []Scope
	DefaultValue  any

	// (hunk gap — remaining fields unchanged)

		return fmt.Errorf("key def %q value type %q: %w", keyDef.Key, keyDef.ValueType, ErrInvalidValueType)
	}

	if !keyDef.RedactPolicy.IsValid() {
		return fmt.Errorf("key def %q redact policy %q: %w", keyDef.Key, keyDef.RedactPolicy, ErrInvalidRedactPolicy)
	}

	if keyDef.EnvVar != "" && !envVarPattern.MatchString(keyDef.EnvVar) {
		return fmt.Errorf("key def %q env var %q: %w", keyDef.Key, keyDef.EnvVar, ErrInvalidEnvVar)
	}

	// Secret keys are always treated as RedactFull at runtime (see
	// normalizeRedactPolicy in catalog/validate.go), so any RedactPolicy
	// declared on a secret key is acceptable — no error needed here.

	if !keyDef.ApplyBehavior.IsValid() {
		return fmt.Errorf("key def %q apply behavior %q: %w", keyDef.Key, keyDef.ApplyBehavior, ErrInvalidApplyBehavior)
	}

// ---- file: commons/systemplane/domain/key_def_test.go (diff hunks) ----

// TestKeyDef_Validate_InvalidRedactPolicy verifies that an unknown redact
// policy string is rejected with ErrInvalidRedactPolicy.
func TestKeyDef_Validate_InvalidRedactPolicy(t *testing.T) {
	t.Parallel()

	kd := validKeyDef()
	kd.RedactPolicy = RedactPolicy("bogus")

	err := kd.Validate()

	require.Error(t, err)
	assert.ErrorIs(t, err, ErrInvalidRedactPolicy)
}

// TestKeyDef_Validate_InvalidEnvVar verifies that malformed env var names
// (lowercase, dashes, stray whitespace) are rejected with ErrInvalidEnvVar.
func TestKeyDef_Validate_InvalidEnvVar(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		envVar string
	}{
		{name: "lowercase", envVar: "postgres_host"},
		{name: "dash", envVar: "POSTGRES-HOST"},
		{name: "leading space", envVar: " POSTGRES_HOST"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			kd := validKeyDef()
			kd.EnvVar = tt.envVar

			err := kd.Validate()

			require.Error(t, err)
			assert.ErrorIs(t, err, ErrInvalidEnvVar)
		})
	}
}

// TestKeyDef_Validate_ValidEnvVar verifies a conventional env var name passes.
func TestKeyDef_Validate_ValidEnvVar(t *testing.T) {
	t.Parallel()

	kd := validKeyDef()
	kd.EnvVar = "POSTGRES_HOST"

	require.NoError(t, kd.Validate())
}

func TestKeyDef_Validate_SecretMaskAccepted(t *testing.T) {
	t.Parallel()

	// Secret+RedactMask is accepted by Validate (not an error). The runtime
	// normalizes it to RedactFull via normalizeRedactPolicy in the catalog
	// and snapshot layers.
	kd := validKeyDef()
	kd.Secret = true
	kd.RedactPolicy = RedactMask

	err := kd.Validate()

	require.NoError(t, err)
}

func TestKeyDef_Validate_EmptyAllowedScopes(t *testing.T) {
	t.Parallel()

// ---- file: commons/systemplane/domain/setting_helpers.go ----

// Copyright 2025 Lerian Studio.

package domain

// SnapSettingString returns a setting value coerced to string, cascading from
// tenant-scoped to global to fallback. Nil-safe: returns fallback when snap is nil.
// Note: a tenant value that fails coercion falls through to the global value.
func SnapSettingString(snap *Snapshot, tenantID, key string, fallback string) string {
	if snap == nil {
		return fallback
	}

	if raw, ok := snap.GetTenantSetting(tenantID, key); ok {
		if value, converted := tryCoerceString(raw.Value); converted {
			return value
		}
	}

	if raw, ok := snap.GetGlobalSetting(key); ok {
		if value, converted := tryCoerceString(raw.Value); converted {
			return value
		}
	}

	return fallback
}
// SnapSettingInt returns a setting value coerced to int, cascading from
// tenant-scoped to global to fallback. Nil-safe: returns fallback when snap is nil.
// Note: a tenant value that fails coercion falls through to the global value.
func SnapSettingInt(snap *Snapshot, tenantID, key string, fallback int) int {
	if snap == nil {
		return fallback
	}

	if raw, ok := snap.GetTenantSetting(tenantID, key); ok {
		if value, converted := tryCoerceInt(raw.Value); converted {
			return value
		}
	}

	if raw, ok := snap.GetGlobalSetting(key); ok {
		if value, converted := tryCoerceInt(raw.Value); converted {
			return value
		}
	}

	return fallback
}

// SnapSettingBool returns a setting value coerced to bool, cascading from
// tenant-scoped to global to fallback. Nil-safe: returns fallback when snap is nil.
// Note: a tenant value that fails coercion falls through to the global value.
func SnapSettingBool(snap *Snapshot, tenantID, key string, fallback bool) bool {
	if snap == nil {
		return fallback
	}

	if raw, ok := snap.GetTenantSetting(tenantID, key); ok {
		if value, converted := tryCoerceBool(raw.Value); converted {
			return value
		}
	}

	if raw, ok := snap.GetGlobalSetting(key); ok {
		if value, converted := tryCoerceBool(raw.Value); converted {
			return value
		}
	}

	return fallback
}

// ---- file: commons/systemplane/domain/snapshot.go (diff hunks; gaps elided) ----

// ConfigValue returns the configuration value for the given key, or the
// fallback if the key is not present. Nil-receiver safe.
func (s *Snapshot) ConfigValue(key string, fallback any) any {
	if s == nil {
		return fallback
	}

	if v, ok := s.GetConfig(key); ok {
		return v.Value
	}

	return fallback
}

// GlobalSettingValue returns the setting value for the given key, or the fallback
// if the key is not present. Nil-receiver safe.
func (s *Snapshot) GlobalSettingValue(key string, fallback any) any {
	if s == nil {
		return fallback
	}

	if v, ok := s.GetGlobalSetting(key); ok {
		return v.Value
	}

	return fallback
}

// TenantSettingValue returns the tenant setting value for the given key, or the
// fallback if the key is not present. Nil-receiver safe.
func (s *Snapshot) TenantSettingValue(tenantID, key string, fallback any) any {
	if s == nil {
		return fallback
	}

	if v, ok := s.GetTenantSetting(tenantID, key); ok {
		return v.Value
	}

	return fallback
}

// ---- file: commons/systemplane/domain/snapshot_config_helpers_test.go ----

//go:build unit

// Copyright 2025 Lerian Studio.

package domain

import (
	"encoding/json"
	"fmt"
	"math"
	"strconv"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestSnapString covers string coercion paths, including nil snapshots,
// typed-nil and panicking fmt.Stringer values, and fmt.Sprint fallbacks.
func TestSnapString(t *testing.T) {
	t.Parallel()

	var nilStringer *panicStringer

	var nilInt *int

	panickingStringer := &panicStringer{}

	tests := []struct {
		name     string
		snap     *Snapshot
		key      string
		fallback string
		want     string
	}{
		{name: "nil snapshot", snap: nil, key: "k", fallback: "fb", want: "fb"},
		{name: "missing key", snap: snapWith("other", "v"), key: "k", fallback: "fb", want: "fb"},
		{name: "direct string", snap: snapWith("k", "hello"), key: "k", fallback: "fb", want: "hello"},
		{name: "empty string is not missing", snap: snapWith("k", ""), key: "k", fallback: "fb", want: ""},
		{name: "[]byte", snap: snapWith("k", []byte("bytes")), key: "k", fallback: "fb", want: "bytes"},
		{name: "fmt.Stringer", snap: snapWith("k", stringer{s: "custom"}), key: "k", fallback: "fb", want: "custom"},
		{name: "typed nil fmt.Stringer falls back", snap: snapWith("k", nilStringer), key: "k", fallback: "fb", want: "fb"},
		{name: "panicking fmt.Stringer falls back", snap: snapWith("k", panickingStringer), key: "k", fallback: "fb", want: "fb"},
		{name: "typed nil pointer falls back", snap: snapWith("k", nilInt), key: "k", fallback: "fb", want: "fb"},
		{name: "int via Sprint", snap: snapWith("k", 42), key: "k", fallback: "fb", want: "42"},
		{name: "bool via Sprint", snap: snapWith("k", true), key: "k", fallback: "fb", want: "true"},
		{name: "nil value in config", snap: snapWith("k", nil), key: "k", fallback: "fb", want: "fb"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			assert.NotPanics(t, func() {
				got := SnapString(tt.snap, tt.key, tt.fallback)
				assert.Equal(t, tt.want, got)
			})
		})
	}
}
// TestSnapInt covers int coercion: direct ints, widening/narrowing, float
// truncation, string parsing, and json.Number handling.
func TestSnapInt(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		snap     *Snapshot
		key      string
		fallback int
		want     int
	}{
		{name: "nil snapshot", snap: nil, key: "k", fallback: -1, want: -1},
		{name: "missing key", snap: snapWith("other", 1), key: "k", fallback: -1, want: -1},
		{name: "direct int", snap: snapWith("k", 42), key: "k", fallback: -1, want: 42},
		{name: "zero int is not missing", snap: snapWith("k", 0), key: "k", fallback: -1, want: 0},
		{name: "int64", snap: snapWith("k", int64(100)), key: "k", fallback: -1, want: 100},
		{name: "float64 whole", snap: snapWith("k", float64(7)), key: "k", fallback: -1, want: 7},
		{name: "float64 fractional truncates", snap: snapWith("k", 7.9), key: "k", fallback: -1, want: 7},
		{name: "float64 negative truncates toward zero", snap: snapWith("k", -7.9), key: "k", fallback: -1, want: -7},
		{name: "string parseable", snap: snapWith("k", "99"), key: "k", fallback: -1, want: 99},
		{name: "string unparseable", snap: snapWith("k", "abc"), key: "k", fallback: -1, want: -1},
		{name: "json.Number valid", snap: snapWith("k", json.Number("123")), key: "k", fallback: -1, want: 123},
		{name: "json.Number invalid", snap: snapWith("k", json.Number("nope")), key: "k", fallback: -1, want: -1},
		{name: "bool returns fallback", snap: snapWith("k", true), key: "k", fallback: -1, want: -1},
		{name: "nil value", snap: snapWith("k", nil), key: "k", fallback: -1, want: -1},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapInt(tt.snap, tt.key, tt.fallback)
			assert.Equal(t, tt.want, got)
		})
	}
}

// TestSnapInt_OverflowAndSpecialFloatFallbacks verifies that NaN, ±Inf, and
// out-of-range values fall back; the int64 overflow case only applies on
// 32-bit platforms, hence the conditional append.
func TestSnapInt_OverflowAndSpecialFloatFallbacks(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		val  any
	}{
		{name: "NaN", val: math.NaN()},
		{name: "+Inf", val: math.Inf(1)},
		{name: "-Inf", val: math.Inf(-1)},
		{name: "json.Number overflow", val: json.Number("999999999999999999999999")},
	}

	if strconv.IntSize == 32 {
		tests = append(tests, struct {
			name string
			val  any
		}{name: "int64 overflow on 32-bit", val: int64(math.MaxInt64)})
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapInt(snapWith("k", tt.val), "k", -1)
			assert.Equal(t, -1, got)
		})
	}
}

// TestSnapBool covers bool coercion: only canonical string forms and the
// ints 0/1 convert; everything else falls back.
func TestSnapBool(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		snap     *Snapshot
		key      string
		fallback bool
		want     bool
	}{
		{name: "nil snapshot", snap: nil, key: "k", fallback: true, want: true},
		{name: "missing key", snap: snapWith("other", true), key: "k", fallback: true, want: true},
		{name: "direct true", snap: snapWith("k", true), key: "k", fallback: false, want: true},
		{name: "direct false is not missing", snap: snapWith("k", false), key: "k", fallback: true, want: false},
		{name: "string true", snap: snapWith("k", "true"), key: "k", fallback: false, want: true},
		{name: "string false", snap: snapWith("k", "false"), key: "k", fallback: true, want: false},
		{name: "string 1", snap: snapWith("k", "1"), key: "k", fallback: false, want: true},
		{name: "string 0", snap: snapWith("k", "0"), key: "k", fallback: true, want: false},
		{name: "string t falls back", snap: snapWith("k", "t"), key: "k", fallback: false, want: false},
		{name: "string f falls back", snap: snapWith("k", "f"), key: "k", fallback: true, want: true},
		{name: "string TRUE (case-insensitive)", snap: snapWith("k", "TRUE"), key: "k", fallback: false, want: true},
		{name: "string True (case-insensitive)", snap: snapWith("k", "True"), key: "k", fallback: false, want: true},
		{name: "string FALSE (case-insensitive)", snap: snapWith("k", "FALSE"), key: "k", fallback: true, want: false},
		{name: "string False (case-insensitive)", snap: snapWith("k", "False"), key: "k", fallback: true, want: false},
		{name: "string unparseable", snap: snapWith("k", "maybe"), key: "k", fallback: false, want: false},
		{name: "int non-zero", snap: snapWith("k", 1), key: "k", fallback: false, want: true},
		{name: "int two falls back", snap: snapWith("k", 2), key: "k", fallback: false, want: false},
		{name: "int negative falls back", snap: snapWith("k", -1), key: "k", fallback: true, want: true},
		{name: "int zero", snap: snapWith("k", 0), key: "k", fallback: true, want: false},
		{name: "int64 returns fallback", snap: snapWith("k", int64(42)), key: "k", fallback: false, want: false},
		{name: "float64 returns fallback", snap: snapWith("k", 3.14), key: "k", fallback: true, want: true},
		{name: "nil value", snap: snapWith("k", nil), key: "k", fallback: true, want: true},
		{name: "slice returns fallback", snap: snapWith("k", []string{"a"}), key: "k", fallback: true, want: true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapBool(tt.snap, tt.key, tt.fallback)
			assert.Equal(t, tt.want, got)
		})
	}
}

// TestSnapFloat64 covers float coercion; NaN/±Inf (direct or via string)
// fall back so downstream math never sees non-finite values.
func TestSnapFloat64(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		snap     *Snapshot
		key      string
		fallback float64
		want     float64
	}{
		{name: "nil snapshot", snap: nil, key: "k", fallback: -1, want: -1},
		{name: "missing key", snap: snapWith("other", 1.0), key: "k", fallback: -1, want: -1},
		{name: "direct float64", snap: snapWith("k", 3.14), key: "k", fallback: -1, want: 3.14},
		{name: "zero float is not missing", snap: snapWith("k", float64(0)), key: "k", fallback: -1, want: 0},
		{name: "int", snap: snapWith("k", 7), key: "k", fallback: -1, want: 7.0},
		{name: "int64", snap: snapWith("k", int64(100)), key: "k", fallback: -1, want: 100.0},
		{name: "string parseable", snap: snapWith("k", "2.718"), key: "k", fallback: -1, want: 2.718},
		{name: "string unparseable", snap: snapWith("k", "abc"), key: "k", fallback: -1, want: -1},
		{name: "json.Number valid", snap: snapWith("k", json.Number("9.81")), key: "k", fallback: -1, want: 9.81},
		{name: "json.Number invalid", snap: snapWith("k", json.Number("bad")), key: "k", fallback: -1, want: -1},
		{name: "direct NaN falls back", snap: snapWith("k", math.NaN()), key: "k", fallback: -1, want: -1},
		{name: "direct +Inf falls back", snap: snapWith("k", math.Inf(1)), key: "k", fallback: -1, want: -1},
		{name: "direct -Inf falls back", snap: snapWith("k", math.Inf(-1)), key: "k", fallback: -1, want: -1},
		{name: "string NaN falls back", snap: snapWith("k", "NaN"), key: "k", fallback: -1, want: -1},
		{name: "string +Inf falls back", snap: snapWith("k", "+Inf"), key: "k", fallback: -1, want: -1},
		{name: "bool returns fallback", snap: snapWith("k", true), key: "k", fallback: -1, want: -1},
		{name: "nil value", snap: snapWith("k", nil), key: "k", fallback: -1, want: -1},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapFloat64(tt.snap, tt.key, tt.fallback)
			assert.InDelta(t, tt.want, got, 1e-9)
		})
	}
}

// TestSnapString_StringerInterface verifies a value is read through its
// fmt.Stringer implementation.
func TestSnapString_StringerInterface(t *testing.T) {
	t.Parallel()

	snap := snapWith("k", stringer{s: "via-stringer"})
	got := SnapString(snap, "k", "fb")
	assert.Equal(t, "via-stringer", got)
}

// TestSnapStringSlice_ConfigValueIsCloned verifies mutating the returned
// slice does not touch the value stored in the snapshot.
func TestSnapStringSlice_ConfigValueIsCloned(t *testing.T) {
	t.Parallel()

	original := []string{"a", "b"}
	got := SnapStringSlice(snapWith("k", original), "k", nil)
	got[0] = "changed"
	assert.Equal(t, []string{"a", "b"}, original)
}

// TestSnapConfigHelpers_FmtDescriptionsStayStable pins the fmt.Sprint-based
// rendering of non-string values.
func TestSnapConfigHelpers_FmtDescriptionsStayStable(t *testing.T) {
	t.Parallel()

	assert.Equal(t, "7", SnapString(snapWith("k", 7), "k", "fb"))
	assert.Equal(t, fmt.Sprint(true), SnapString(snapWith("k", true), "k", "fb"))
}

// ---- file: commons/systemplane/domain/snapshot_duration_slice_helpers_test.go ----

//go:build unit

// Copyright 2025 Lerian Studio.

package domain

import (
	"math"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestSnapDuration covers numeric scaling by unit and string parsing via
// time.ParseDuration (bare numeric strings are deliberately rejected).
func TestSnapDuration(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		snap     *Snapshot
		key      string
		fallback time.Duration
		unit     time.Duration
		want     time.Duration
	}{
		{name: "nil snapshot", snap: nil, key: "k", fallback: 5 * time.Second, unit: time.Second, want: 5 * time.Second},
		{name: "missing key", snap: snapWith("other", 1), key: "k", fallback: 5 * time.Second, unit: time.Second, want: 5 * time.Second},
		{name: "int seconds", snap: snapWith("k", 30), key: "k", fallback: 0, unit: time.Second, want: 30 * time.Second},
		{name: "int64 millis", snap: snapWith("k", int64(500)), key: "k", fallback: 0, unit: time.Millisecond, want: 500 * time.Millisecond},
		{name: "float64 seconds", snap: snapWith("k", 1.5), key: "k", fallback: 0, unit: time.Second, want: 1500 * time.Millisecond},
		{name: "zero int is not fallback", snap: snapWith("k", 0), key: "k", fallback: 5 * time.Second, unit: time.Second, want: 0},
		{name: "zero float is not fallback", snap: snapWith("k", 0.0), key: "k", fallback: 5 * time.Second, unit: time.Second, want: 0},
		{name: "string parseable duration", snap: snapWith("k", "2m30s"), key: "k", fallback: 0, unit: time.Second, want: 2*time.Minute + 30*time.Second},
		{name: "string numeric returns fallback", snap: snapWith("k", "10"), key: "k", fallback: 3 * time.Second, unit: time.Second, want: 3 * time.Second},
		{name: "string unparseable", snap: snapWith("k", "nope"), key: "k", fallback: 3 * time.Second, unit: time.Second, want: 3 * time.Second},
		{name: "nil value", snap: snapWith("k", nil), key: "k", fallback: 7 * time.Second, unit: time.Second, want: 7 * time.Second},
		{name: "bool returns fallback", snap: snapWith("k", true), key: "k", fallback: 4 * time.Second, unit: time.Second, want: 4 * time.Second},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapDuration(tt.snap, tt.key, tt.fallback, tt.unit)
			assert.Equal(t, tt.want, got)
		})
	}
}
// TestSnapDuration_Float64NaNInf verifies non-finite floats fall back.
func TestSnapDuration_Float64NaNInf(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		value float64
	}{
		{name: "NaN", value: math.NaN()},
		{name: "+Inf", value: math.Inf(1)},
		{name: "-Inf", value: math.Inf(-1)},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapDuration(snapWith("k", tt.value), "k", 5*time.Second, time.Second)
			assert.Equal(t, 5*time.Second, got)
		})
	}
}

// TestSnapDuration_OverflowFallsBack verifies values whose unit-scaled result
// would overflow time.Duration fall back.
func TestSnapDuration_OverflowFallsBack(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		val  any
	}{
		{name: "int64 overflow", val: int64(math.MaxInt64)},
		{name: "float64 overflow", val: float64(math.MaxInt64)},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapDuration(snapWith("k", tt.val), "k", 5*time.Second, time.Second)
			assert.Equal(t, 5*time.Second, got)
		})
	}
}

// TestSnapStringSlice covers slice coercion: []string passthrough (no
// trimming), []any of strings (trimmed, blanks dropped, any non-string element
// fails the whole value), and comma-separated strings.
func TestSnapStringSlice(t *testing.T) {
	t.Parallel()

	fb := []string{"fallback"}

	tests := []struct {
		name     string
		snap     *Snapshot
		key      string
		fallback []string
		want     []string
	}{
		{name: "nil snapshot", snap: nil, key: "k", fallback: fb, want: fb},
		{name: "missing key", snap: snapWith("other", "x"), key: "k", fallback: fb, want: fb},
		{name: "direct []string", snap: snapWith("k", []string{"a", "b"}), key: "k", fallback: fb, want: []string{"a", "b"}},
		{name: "empty []string is not missing", snap: snapWith("k", []string{}), key: "k", fallback: fb, want: []string{}},
		{name: "typed nil []string falls back", snap: snapWith("k", []string(nil)), key: "k", fallback: fb, want: fb},
		{name: "typed nil []any falls back", snap: snapWith("k", []any(nil)), key: "k", fallback: fb, want: fb},
		{name: "[]any strings are trimmed", snap: snapWith("k", []any{"x", " y ", ""}), key: "k", fallback: fb, want: []string{"x", "y"}},
		{name: "[]any non-string falls back", snap: snapWith("k", []any{"x", 1, true}), key: "k", fallback: fb, want: fb},
		{name: "comma separated string", snap: snapWith("k", "a, b, c"), key: "k", fallback: fb, want: []string{"a", "b", "c"}},
		{name: "comma separated string filters blanks", snap: snapWith("k", "a, , b,,"), key: "k", fallback: fb, want: []string{"a", "b"}},
		{name: "empty string becomes empty slice", snap: snapWith("k", ""), key: "k", fallback: fb, want: []string{}},
		{name: "single element string", snap: snapWith("k", "only"), key: "k", fallback: fb, want: []string{"only"}},
		{name: "nil value", snap: snapWith("k", nil), key: "k", fallback: fb, want: fb},
		{name: "int returns fallback", snap: snapWith("k", 42), key: "k", fallback: fb, want: fb},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := SnapStringSlice(tt.snap, tt.key, tt.fallback)
			assert.Equal(t, tt.want, got)
		})
	}
}

// TestSnapStringSlice_FallbackIsCloned verifies the fallback slice is cloned
// on every fallback path so callers cannot mutate the shared default.
// NOTE(review): this function is truncated at the end of the visible chunk;
// the remainder continues past this view.
func TestSnapStringSlice_FallbackIsCloned(t *testing.T) {
	t.Parallel()

	fallback := []string{"fallback"}
	for _, got := range [][]string{
		SnapStringSlice(nil, "k", fallback),
		SnapStringSlice(&Snapshot{}, "k", fallback),
		SnapStringSlice(snapWith("k", 42), "k", fallback),
	} {
		got[0] = "changed"
	}

	assert.Equal(t, []string{"fallback"},
fallback) +} diff --git a/commons/systemplane/domain/snapshot_from_keydefs.go b/commons/systemplane/domain/snapshot_from_keydefs.go new file mode 100644 index 00000000..616ca222 --- /dev/null +++ b/commons/systemplane/domain/snapshot_from_keydefs.go @@ -0,0 +1,69 @@ +// Copyright 2025 Lerian Studio. + +package domain + +import "time" + +// DefaultSnapshotFromKeyDefs builds a Snapshot pre-seeded with the default +// values declared in the given KeyDefs. Only KindConfig definitions are +// included; settings and tenant-scoped keys are skipped. +// +// The resulting snapshot has RevisionZero and Source "registry-default" for +// every entry. This is useful for tests, pre-store bootstrap, and any +// context where a SnapshotBuilder (which requires a live Store) is not yet +// available. +func DefaultSnapshotFromKeyDefs(defs []KeyDef) Snapshot { + configs := make(map[string]EffectiveValue, len(defs)) + + for _, def := range defs { + if def.Kind != KindConfig { + continue + } + + configs[def.Key] = EffectiveValue{ + Key: def.Key, + Value: cloneRuntimeValue(def.DefaultValue), + Default: cloneRuntimeValue(def.DefaultValue), + Source: "registry-default", + Revision: RevisionZero, + Redacted: def.Secret || (def.RedactPolicy != "" && def.RedactPolicy != RedactNone), + } + } + + return Snapshot{ + Configs: configs, + GlobalSettings: make(map[string]EffectiveValue), + TenantSettings: make(map[string]map[string]EffectiveValue), + BuiltAt: time.Now().UTC(), + } +} + +// cloneRuntimeValue returns a deep copy for map[string]any and []any values +// to prevent aliasing between the runtime Value and the stored Default in +// an EffectiveValue. Scalar types (string, int, bool, etc.) are returned as-is +// since they are immutable. +// +// This intentionally uses type assertions (not reflection) because systemplane +// config values are always JSON-decoded into map[string]any / []any. 
The +// narrower scope avoids the cost and complexity of reflect-based cloning in +// manager_helpers.go, which handles arbitrary Go types for reconciler bundles. +func cloneRuntimeValue(v any) any { + switch x := v.(type) { + case map[string]any: + out := make(map[string]any, len(x)) + for k, vv := range x { + out[k] = cloneRuntimeValue(vv) + } + + return out + case []any: + out := make([]any, len(x)) + for i, vv := range x { + out[i] = cloneRuntimeValue(vv) + } + + return out + default: + return v + } +} diff --git a/commons/systemplane/domain/snapshot_from_keydefs_test.go b/commons/systemplane/domain/snapshot_from_keydefs_test.go new file mode 100644 index 00000000..9fd2f70a --- /dev/null +++ b/commons/systemplane/domain/snapshot_from_keydefs_test.go @@ -0,0 +1,121 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. + +package domain + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultSnapshotFromKeyDefs_PopulatesConfigKeys(t *testing.T) { + t.Parallel() + + defs := []KeyDef{ + {Key: "app.log_level", Kind: KindConfig, DefaultValue: "info", ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}}, + {Key: "rate_limit.max", Kind: KindConfig, DefaultValue: 100, ValueType: ValueTypeInt, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}}, + } + + snap := DefaultSnapshotFromKeyDefs(defs) + + require.Len(t, snap.Configs, 2) + + logLevel, ok := snap.Configs["app.log_level"] + require.True(t, ok) + assert.Equal(t, "info", logLevel.Value) + assert.Equal(t, "info", logLevel.Default) + assert.Equal(t, "registry-default", logLevel.Source) + assert.Equal(t, RevisionZero, logLevel.Revision) + assert.False(t, logLevel.Redacted) + + rateLimit, ok := snap.Configs["rate_limit.max"] + require.True(t, ok) + assert.Equal(t, 100, rateLimit.Value) +} + +func TestDefaultSnapshotFromKeyDefs_SkipsNonConfigKinds(t *testing.T) { + t.Parallel() + + defs := 
[]KeyDef{ + {Key: "config.key", Kind: KindConfig, DefaultValue: "v", ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}}, + {Key: "setting.key", Kind: KindSetting, DefaultValue: "v", ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}}, + } + + snap := DefaultSnapshotFromKeyDefs(defs) + + require.Len(t, snap.Configs, 1) + _, ok := snap.Configs["config.key"] + assert.True(t, ok) + _, ok = snap.Configs["setting.key"] + assert.False(t, ok) +} + +func TestDefaultSnapshotFromKeyDefs_MarksRedactedKeys(t *testing.T) { + t.Parallel() + + defs := []KeyDef{ + {Key: "secret.key", Kind: KindConfig, DefaultValue: "", ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}, RedactPolicy: RedactFull}, + {Key: "public.key", Kind: KindConfig, DefaultValue: "", ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}, RedactPolicy: RedactNone}, + {Key: "masked.key", Kind: KindConfig, DefaultValue: "", ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}, RedactPolicy: RedactMask}, + {Key: "empty.policy", Kind: KindConfig, DefaultValue: "", ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}}, + } + + snap := DefaultSnapshotFromKeyDefs(defs) + + assert.True(t, snap.Configs["secret.key"].Redacted) + assert.False(t, snap.Configs["public.key"].Redacted) + assert.True(t, snap.Configs["masked.key"].Redacted) + assert.False(t, snap.Configs["empty.policy"].Redacted) +} + +func TestDefaultSnapshotFromKeyDefs_EmptyInput(t *testing.T) { + t.Parallel() + + snap := DefaultSnapshotFromKeyDefs(nil) + + assert.Empty(t, snap.Configs) + assert.False(t, snap.BuiltAt.IsZero()) +} + +func TestDefaultSnapshotFromKeyDefs_SetsBuiltAt(t *testing.T) { + t.Parallel() + + snap := DefaultSnapshotFromKeyDefs([]KeyDef{ + {Key: "k", Kind: KindConfig, DefaultValue: "v", 
ValueType: ValueTypeString, ApplyBehavior: ApplyLiveRead, AllowedScopes: []Scope{ScopeGlobal}}, + }) + + assert.False(t, snap.BuiltAt.IsZero()) + // Use zone offset (0 == UTC) instead of Location().String() to avoid + // platform-specific timezone name differences (e.g., "UTC" vs ""). + _, offset := snap.BuiltAt.Zone() + assert.Equal(t, 0, offset) +} + +func TestDefaultSnapshotFromKeyDefs_SecretWithEmptyRedactPolicy(t *testing.T) { + t.Parallel() + + // A key with Secret=true and no explicit RedactPolicy (zero value ""). + // The production code sets Redacted = def.Secret || (def.RedactPolicy != "" && def.RedactPolicy != RedactNone). + // With Secret=true, Redacted must be true regardless of RedactPolicy. + defs := []KeyDef{ + { + Key: "auth.token", + Kind: KindConfig, + DefaultValue: "", + ValueType: ValueTypeString, + ApplyBehavior: ApplyBootstrapOnly, + AllowedScopes: []Scope{ScopeGlobal}, + Secret: true, + RedactPolicy: "", // explicitly empty — the "no policy declared" case + }, + } + + snap := DefaultSnapshotFromKeyDefs(defs) + + entry, ok := snap.Configs["auth.token"] + require.True(t, ok) + assert.True(t, entry.Redacted, "Secret=true must force Redacted=true even when RedactPolicy is empty") +} diff --git a/commons/systemplane/domain/snapshot_setting_helpers_test.go b/commons/systemplane/domain/snapshot_setting_helpers_test.go new file mode 100644 index 00000000..559fd11a --- /dev/null +++ b/commons/systemplane/domain/snapshot_setting_helpers_test.go @@ -0,0 +1,105 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. 
+ +package domain + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSnapSettingString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + snap *Snapshot + tenant string + key string + fallback string + want string + }{ + {name: "nil snapshot", snap: nil, tenant: "t1", key: "k", fallback: "fb", want: "fb"}, + {name: "tenant value wins", snap: snapWithBoth("t1", "k", "global", "tenant"), tenant: "t1", key: "k", fallback: "fb", want: "tenant"}, + {name: "falls through to global", snap: snapWithGlobal("k", "global"), tenant: "t1", key: "k", fallback: "fb", want: "global"}, + {name: "falls through to fallback", snap: &Snapshot{}, tenant: "t1", key: "k", fallback: "fb", want: "fb"}, + {name: "tenant value with type coercion", snap: snapWithTenant("t1", "k", 42), tenant: "t1", key: "k", fallback: "fb", want: "42"}, + {name: "tenant empty string overrides global", snap: snapWithBoth("t1", "k", "global", ""), tenant: "t1", key: "k", fallback: "fb", want: ""}, + {name: "tenant nil falls through to global", snap: snapWithBoth("t1", "k", "global", nil), tenant: "t1", key: "k", fallback: "fb", want: "global"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := SnapSettingString(tt.snap, tt.tenant, tt.key, tt.fallback) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestSnapSettingInt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + snap *Snapshot + tenant string + key string + fallback int + want int + }{ + {name: "nil snapshot", snap: nil, tenant: "t1", key: "k", fallback: -1, want: -1}, + {name: "tenant value wins", snap: snapWithBoth("t1", "k", 10, 20), tenant: "t1", key: "k", fallback: -1, want: 20}, + {name: "falls through to global", snap: snapWithGlobal("k", 10), tenant: "t1", key: "k", fallback: -1, want: 10}, + {name: "falls through to fallback", snap: &Snapshot{}, tenant: "t1", key: "k", fallback: -1, want: -1}, + {name: "global with string 
coercion", snap: snapWithGlobal("k", "77"), tenant: "t1", key: "k", fallback: -1, want: 77}, + {name: "tenant nil falls through to global", snap: snapWithBoth("t1", "k", 7, nil), tenant: "t1", key: "k", fallback: -1, want: 7}, + {name: "tenant malformed falls through to global", snap: snapWithBoth("t1", "k", 7, []string{"bad"}), tenant: "t1", key: "k", fallback: -1, want: 7}, + {name: "global malformed falls through to fallback", snap: snapWithGlobal("k", []string{"bad"}), tenant: "t1", key: "k", fallback: -1, want: -1}, + {name: "tenant zero overrides global", snap: snapWithBoth("t1", "k", 7, 0), tenant: "t1", key: "k", fallback: -1, want: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := SnapSettingInt(tt.snap, tt.tenant, tt.key, tt.fallback) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestSnapSettingBool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + snap *Snapshot + tenant string + key string + fallback bool + want bool + }{ + {name: "nil snapshot", snap: nil, tenant: "t1", key: "k", fallback: false, want: false}, + {name: "tenant value wins", snap: snapWithBoth("t1", "k", false, true), tenant: "t1", key: "k", fallback: false, want: true}, + {name: "falls through to global", snap: snapWithGlobal("k", true), tenant: "t1", key: "k", fallback: false, want: true}, + {name: "falls through to fallback", snap: &Snapshot{}, tenant: "t1", key: "k", fallback: true, want: true}, + {name: "tenant string coercion", snap: snapWithTenant("t1", "k", "true"), tenant: "t1", key: "k", fallback: false, want: true}, + {name: "tenant invalid bool string falls through to global", snap: snapWithBoth("t1", "k", true, "t"), tenant: "t1", key: "k", fallback: false, want: true}, + {name: "tenant int coercion zero", snap: snapWithTenant("t1", "k", 0), tenant: "t1", key: "k", fallback: true, want: false}, + {name: "tenant invalid int falls through to global", snap: snapWithBoth("t1", "k", true, 2), tenant: 
"t1", key: "k", fallback: false, want: true}, + {name: "tenant nil falls through to global", snap: snapWithBoth("t1", "k", true, nil), tenant: "t1", key: "k", fallback: false, want: true}, + {name: "tenant malformed falls through to global", snap: snapWithBoth("t1", "k", true, []string{"bad"}), tenant: "t1", key: "k", fallback: false, want: true}, + {name: "global malformed falls through to fallback", snap: snapWithGlobal("k", []string{"bad"}), tenant: "t1", key: "k", fallback: false, want: false}, + {name: "tenant false overrides global true", snap: snapWithBoth("t1", "k", true, false), tenant: "t1", key: "k", fallback: true, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := SnapSettingBool(tt.snap, tt.tenant, tt.key, tt.fallback) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/commons/systemplane/domain/snapshot_test_helpers_test.go b/commons/systemplane/domain/snapshot_test_helpers_test.go new file mode 100644 index 00000000..a247451f --- /dev/null +++ b/commons/systemplane/domain/snapshot_test_helpers_test.go @@ -0,0 +1,54 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. 
+ +package domain + +type stringer struct{ s string } + +func (s stringer) String() string { return s.s } + +type panicStringer struct{} + +func (*panicStringer) String() string { + panic("typed-nil stringer should not be invoked") +} + +func snapWith(key string, val any) *Snapshot { + return &Snapshot{ + Configs: map[string]EffectiveValue{ + key: {Key: key, Value: val}, + }, + } +} + +func snapWithGlobal(key string, val any) *Snapshot { + return &Snapshot{ + GlobalSettings: map[string]EffectiveValue{ + key: {Key: key, Value: val}, + }, + } +} + +func snapWithTenant(tenantID, key string, val any) *Snapshot { + return &Snapshot{ + TenantSettings: map[string]map[string]EffectiveValue{ + tenantID: { + key: {Key: key, Value: val}, + }, + }, + } +} + +func snapWithBoth(tenantID, key string, globalVal, tenantVal any) *Snapshot { + return &Snapshot{ + GlobalSettings: map[string]EffectiveValue{ + key: {Key: key, Value: globalVal}, + }, + TenantSettings: map[string]map[string]EffectiveValue{ + tenantID: { + key: {Key: key, Value: tenantVal}, + }, + }, + } +} diff --git a/commons/systemplane/ports/authorizer_defaults.go b/commons/systemplane/ports/authorizer_defaults.go new file mode 100644 index 00000000..b184661c --- /dev/null +++ b/commons/systemplane/ports/authorizer_defaults.go @@ -0,0 +1,96 @@ +// Copyright 2025 Lerian Studio. + +package ports + +import ( + "context" + "strings" + + "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" +) + +// AllowAllAuthorizer permits authorization when context is non-nil and +// permission is non-empty after trimming spaces. It fails closed with +// [domain.ErrPermissionDenied] for nil receiver, nil context, or blank +// permission. Use only when authentication is explicitly disabled. +type AllowAllAuthorizer struct{} + +// Compile-time interface check. 
+var _ Authorizer = (*AllowAllAuthorizer)(nil) + +func (a *AllowAllAuthorizer) Authorize(ctx context.Context, permission string) error { + if domain.IsNilValue(a) { + return domain.ErrPermissionDenied + } + + if ctx == nil || strings.TrimSpace(permission) == "" { + return domain.ErrPermissionDenied + } + + return nil +} + +// PermissionCheckerFunc is a callback that checks whether the current actor +// has access to the given resource and action. Implementations typically +// delegate to an external auth service (e.g., lib-auth). +type PermissionCheckerFunc func(ctx context.Context, resource, action string) error + +// DelegatingAuthorizer splits permission strings and forwards each check +// to an external auth service via the CheckPermission callback. +// +// Permission format: "resource:action" (e.g., "system/configs:read"). +// Default separator is ":". +type DelegatingAuthorizer struct { + // CheckPermission delegates to the external auth service. + // MUST NOT be nil — if nil, Authorize returns ErrPermissionDenied (fail-closed). + CheckPermission PermissionCheckerFunc + + // Separator splits the permission string into resource and action. + // Defaults to ":" if empty. + Separator string +} + +// Compile-time interface check. 
+var _ Authorizer = (*DelegatingAuthorizer)(nil) + +func (a *DelegatingAuthorizer) Authorize(ctx context.Context, permission string) error { + if domain.IsNilValue(a) || a.CheckPermission == nil { + return domain.ErrPermissionDenied + } + + if ctx == nil { + return domain.ErrPermissionDenied + } + + sep := a.Separator + if sep == "" { + sep = ":" + } + + permission = strings.TrimSpace(permission) + if permission == "" { + return domain.ErrPermissionDenied + } + + resource, action := splitPermission(permission, sep) + resource = strings.TrimSpace(resource) + + action = strings.TrimSpace(action) + if resource == "" || action == "" { + return domain.ErrPermissionDenied + } + + return a.CheckPermission(ctx, resource, action) +} + +// splitPermission splits a permission string by the last occurrence of sep. +// "system/configs:read" with sep ":" -> ("system/configs", "read") +// "admin" with no sep found -> ("admin", "") +func splitPermission(permission, sep string) (resource, action string) { + idx := strings.LastIndex(permission, sep) + if idx < 0 { + return permission, "" + } + + return permission[:idx], permission[idx+len(sep):] +} diff --git a/commons/systemplane/ports/authorizer_defaults_test.go b/commons/systemplane/ports/authorizer_defaults_test.go new file mode 100644 index 00000000..bfe4a68d --- /dev/null +++ b/commons/systemplane/ports/authorizer_defaults_test.go @@ -0,0 +1,333 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. 
+ +package ports + +import ( + "context" + "errors" + "testing" + + "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAllowAllAuthorizer_AllowsNonEmptyPermissionWithContext(t *testing.T) { + t.Parallel() + + auth := &AllowAllAuthorizer{} + + perms := []string{ + "system/configs:read", + "admin:delete", + "anything-at-all", + } + + for _, perm := range perms { + err := auth.Authorize(context.Background(), perm) + assert.NoError(t, err, "permission %q should be allowed", perm) + } +} + +func TestAllowAllAuthorizer_FailsClosed(t *testing.T) { + t.Parallel() + + auth := &AllowAllAuthorizer{} + + assert.ErrorIs(t, auth.Authorize(nil, "system/configs:read"), domain.ErrPermissionDenied) + assert.ErrorIs(t, auth.Authorize(context.Background(), ""), domain.ErrPermissionDenied) + assert.ErrorIs(t, auth.Authorize(context.Background(), " "), domain.ErrPermissionDenied) +} + +func TestAllowAllAuthorizer_TypedNilReceiver_FailsClosed(t *testing.T) { + t.Parallel() + + var auth *AllowAllAuthorizer + + assert.ErrorIs(t, auth.Authorize(context.Background(), "system/configs:read"), domain.ErrPermissionDenied) +} + +func TestDelegatingAuthorizer_SplitsCorrectly(t *testing.T) { + t.Parallel() + + var gotResource, gotAction string + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, resource, action string) error { + gotResource = resource + gotAction = action + + return nil + }, + } + + err := auth.Authorize(context.Background(), "system/configs:read") + + require.NoError(t, err) + assert.Equal(t, "system/configs", gotResource) + assert.Equal(t, "read", gotAction) +} + +func TestDelegatingAuthorizer_CustomSeparator(t *testing.T) { + t.Parallel() + + var gotResource, gotAction string + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, resource, action string) error { + gotResource = resource + gotAction = action + + return nil 
+ }, + Separator: ".", + } + + err := auth.Authorize(context.Background(), "system.write") + + require.NoError(t, err) + assert.Equal(t, "system", gotResource) + assert.Equal(t, "write", gotAction) +} + +func TestDelegatingAuthorizer_NoSeparatorInPermission_Denied(t *testing.T) { + t.Parallel() + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, _, _ string) error { + return nil + }, + } + + err := auth.Authorize(context.Background(), "admin") + + require.ErrorIs(t, err, domain.ErrPermissionDenied) +} + +func TestDelegatingAuthorizer_WhitespaceTrimmed(t *testing.T) { + t.Parallel() + + var gotResource, gotAction string + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, resource, action string) error { + gotResource = resource + gotAction = action + + return nil + }, + } + + err := auth.Authorize(context.Background(), " system/configs : read ") + + require.NoError(t, err) + assert.Equal(t, "system/configs", gotResource) + assert.Equal(t, "read", gotAction) +} + +func TestDelegatingAuthorizer_ExplicitlyMalformedPermission_FailsClosed(t *testing.T) { + t.Parallel() + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, _, _ string) error { + return nil + }, + } + + assert.ErrorIs(t, auth.Authorize(context.Background(), ":"), domain.ErrPermissionDenied) + assert.ErrorIs(t, auth.Authorize(context.Background(), "configs:"), domain.ErrPermissionDenied) + assert.ErrorIs(t, auth.Authorize(context.Background(), ":read"), domain.ErrPermissionDenied) + assert.ErrorIs(t, auth.Authorize(context.Background(), " "), domain.ErrPermissionDenied) + assert.ErrorIs(t, auth.Authorize(context.Background(), "configs: "), domain.ErrPermissionDenied) + assert.ErrorIs(t, auth.Authorize(context.Background(), " :read"), domain.ErrPermissionDenied) +} + +func TestDelegatingAuthorizer_TypedNilReceiver_FailsClosed(t *testing.T) { + t.Parallel() + + var auth *DelegatingAuthorizer + err := 
auth.Authorize(context.Background(), "configs:read") + assert.ErrorIs(t, err, domain.ErrPermissionDenied) +} + +func TestDelegatingAuthorizer_NilChecker_FailsClosed(t *testing.T) { + t.Parallel() + + auth := &DelegatingAuthorizer{ + CheckPermission: nil, + } + + err := auth.Authorize(context.Background(), "anything:read") + + require.ErrorIs(t, err, domain.ErrPermissionDenied) +} + +func TestDelegatingAuthorizer_NilContext_FailsClosed(t *testing.T) { + t.Parallel() + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, _, _ string) error { + return nil + }, + } + + err := auth.Authorize(nil, "anything:read") + + require.ErrorIs(t, err, domain.ErrPermissionDenied) +} + +func TestDelegatingAuthorizer_PropagatesError(t *testing.T) { + t.Parallel() + + customErr := errors.New("external auth service unavailable") + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, _, _ string) error { + return customErr + }, + } + + err := auth.Authorize(context.Background(), "system/configs:read") + + require.ErrorIs(t, err, customErr) +} + +func TestDelegatingAuthorizer_CheckerSuccess(t *testing.T) { + t.Parallel() + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, _, _ string) error { + return nil + }, + } + + err := auth.Authorize(context.Background(), "system/configs:write") + + require.NoError(t, err) +} + +func TestDelegatingAuthorizer_DefaultSeparator(t *testing.T) { + t.Parallel() + + var gotResource, gotAction string + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, resource, action string) error { + gotResource = resource + gotAction = action + + return nil + }, + Separator: "", // should default to ":" + } + + err := auth.Authorize(context.Background(), "configs:delete") + + require.NoError(t, err) + assert.Equal(t, "configs", gotResource) + assert.Equal(t, "delete", gotAction) +} + +func TestDelegatingAuthorizer_MultipleSeparators(t *testing.T) { + t.Parallel() + + var 
gotResource, gotAction string + + auth := &DelegatingAuthorizer{ + CheckPermission: func(_ context.Context, resource, action string) error { + gotResource = resource + gotAction = action + + return nil + }, + } + + // Contains two ":"; should split on the LAST one. + err := auth.Authorize(context.Background(), "system/configs/schema:read") + + require.NoError(t, err) + assert.Equal(t, "system/configs/schema", gotResource) + assert.Equal(t, "read", gotAction) +} + +func TestSplitPermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + permission string + sep string + wantResource string + wantAction string + }{ + { + name: "standard colon split", + permission: "configs:read", + sep: ":", + wantResource: "configs", + wantAction: "read", + }, + { + name: "nested resource with colon", + permission: "system/configs:write", + sep: ":", + wantResource: "system/configs", + wantAction: "write", + }, + { + name: "multiple colons splits on last", + permission: "a:b:c", + sep: ":", + wantResource: "a:b", + wantAction: "c", + }, + { + name: "no separator found", + permission: "admin", + sep: ":", + wantResource: "admin", + wantAction: "", + }, + { + name: "dot separator", + permission: "system.delete", + sep: ".", + wantResource: "system", + wantAction: "delete", + }, + { + name: "multi-char separator", + permission: "system::read", + sep: "::", + wantResource: "system", + wantAction: "read", + }, + { + name: "empty permission", + permission: "", + sep: ":", + wantResource: "", + wantAction: "", + }, + { + name: "separator only", + permission: ":", + sep: ":", + wantResource: "", + wantAction: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resource, action := splitPermission(tt.permission, tt.sep) + + assert.Equal(t, tt.wantResource, resource) + assert.Equal(t, tt.wantAction, action) + }) + } +} diff --git a/commons/systemplane/ports/identity_defaults.go 
b/commons/systemplane/ports/identity_defaults.go new file mode 100644 index 00000000..2cfe264c --- /dev/null +++ b/commons/systemplane/ports/identity_defaults.go @@ -0,0 +1,90 @@ +// Copyright 2025 Lerian Studio. + +package ports + +import ( + "context" + "strings" + + "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" +) + +// FuncIdentityResolver adapts simple extraction functions to the IdentityResolver +// interface. Products wire their existing context-extraction logic (e.g., +// auth.GetUserID, auth.GetTenantID) without writing a full struct. +type FuncIdentityResolver struct { + // ActorFunc extracts the actor ID from context. + // If it returns "", DefaultActor is used when configured; otherwise Actor fails closed. + ActorFunc func(ctx context.Context) string + + // TenantFunc extracts the tenant ID from context. + // It must return a non-empty tenant ID; otherwise TenantID fails closed. + TenantFunc func(ctx context.Context) string + + // DefaultActor is an explicit fallback actor ID used when ActorFunc is nil or returns "". + // If empty, Actor fails closed. + DefaultActor string +} + +// maxActorIDLength bounds the length of an actor ID to prevent abuse. +const maxActorIDLength = 256 + +// maxTenantIDLength bounds the length of a tenant ID to prevent abuse. +const maxTenantIDLength = 256 + +// Compile-time interface check. +var _ IdentityResolver = (*FuncIdentityResolver)(nil) + +// Actor resolves the actor identity from ctx. It fails closed with +// domain.ErrPermissionDenied on nil receiver, nil context, empty/whitespace-only +// IDs, or IDs exceeding maxActorIDLength. When ActorFunc is nil or returns an +// empty or whitespace-only string, DefaultActor is used as a fallback. 
+func (r *FuncIdentityResolver) Actor(ctx context.Context) (domain.Actor, error) { + if domain.IsNilValue(r) { + return domain.Actor{}, domain.ErrPermissionDenied + } + + if ctx == nil { + return domain.Actor{}, domain.ErrPermissionDenied + } + + fallback := strings.TrimSpace(r.DefaultActor) + + id := "" + if r.ActorFunc != nil { + id = strings.TrimSpace(r.ActorFunc(ctx)) + } + + if id == "" { + id = fallback + if id == "" { + return domain.Actor{}, domain.ErrPermissionDenied + } + } + + if len(id) > maxActorIDLength { + return domain.Actor{}, domain.ErrPermissionDenied + } + + return domain.Actor{ID: id}, nil +} + +// TenantID resolves the tenant identity from ctx. It fails closed with +// domain.ErrPermissionDenied on nil receiver, nil context, nil TenantFunc, +// empty/whitespace-only IDs, or IDs exceeding maxTenantIDLength. +func (r *FuncIdentityResolver) TenantID(ctx context.Context) (string, error) { + if domain.IsNilValue(r) || r.TenantFunc == nil { + return "", domain.ErrPermissionDenied + } + + if ctx == nil { + return "", domain.ErrPermissionDenied + } + + tenantID := strings.TrimSpace(r.TenantFunc(ctx)) + if tenantID == "" || len(tenantID) > maxTenantIDLength { + return "", domain.ErrPermissionDenied + } + + return tenantID, nil +} diff --git a/commons/systemplane/ports/identity_defaults_test.go b/commons/systemplane/ports/identity_defaults_test.go new file mode 100644 index 00000000..39c30175 --- /dev/null +++ b/commons/systemplane/ports/identity_defaults_test.go @@ -0,0 +1,227 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. 
package ports

import (
	"context"
	"strings"
	"testing"

	"github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Happy path: both resolver funcs return trimmed, in-range IDs.
func TestFuncIdentityResolver_HappyPath(t *testing.T) {
	t.Parallel()

	resolver := &FuncIdentityResolver{
		ActorFunc: func(_ context.Context) string {
			return "user-42"
		},
		TenantFunc: func(_ context.Context) string {
			return "tenant-abc"
		},
	}

	actor, err := resolver.Actor(context.Background())
	require.NoError(t, err)
	assert.Equal(t, "user-42", actor.ID)

	tenant, err := resolver.TenantID(context.Background())
	require.NoError(t, err)
	assert.Equal(t, "tenant-abc", tenant)
}

// With no ActorFunc at all, DefaultActor is used verbatim.
func TestFuncIdentityResolver_UsesExplicitDefaultActor(t *testing.T) {
	t.Parallel()

	resolver := &FuncIdentityResolver{
		ActorFunc:    nil,
		DefaultActor: "system-bot",
	}

	actor, err := resolver.Actor(context.Background())

	require.NoError(t, err)
	assert.Equal(t, "system-bot", actor.ID)
}

// An ActorFunc that returns only whitespace counts as "no identity",
// so the DefaultActor fallback applies.
func TestFuncIdentityResolver_ActorFuncEmpty_UsesExplicitDefaultActor(t *testing.T) {
	t.Parallel()

	resolver := &FuncIdentityResolver{
		ActorFunc: func(_ context.Context) string {
			return " "
		},
		DefaultActor: "service-account",
	}

	actor, err := resolver.Actor(context.Background())

	require.NoError(t, err)
	assert.Equal(t, "service-account", actor.ID)
}

func TestFuncIdentityResolver_ActorID_ExceedsMaxLength_FailsClosed(t *testing.T) {
	t.Parallel()

	longID := strings.Repeat("a", maxActorIDLength+1)

	resolver := &FuncIdentityResolver{
		ActorFunc: func(_ context.Context) string {
			return longID
		},
	}

	_, err := resolver.Actor(context.Background())
	require.ErrorIs(t, err, domain.ErrPermissionDenied)
}

// Boundary check: an ID of exactly maxActorIDLength is accepted.
func TestFuncIdentityResolver_ActorID_ExactMaxLength_Allowed(t *testing.T) {
	t.Parallel()

	exactID := strings.Repeat("a", maxActorIDLength)

	resolver := &FuncIdentityResolver{
		ActorFunc: func(_ context.Context) string {
			return exactID
		},
	}

	actor, err := resolver.Actor(context.Background())
	require.NoError(t, err)
	assert.Equal(t, exactID, actor.ID)
}

// Fail-closed matrix for Actor: every way identity can be missing must
// yield domain.ErrPermissionDenied, never a zero-value Actor.
func TestFuncIdentityResolver_FailsClosedWithoutIdentity(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		resolver *FuncIdentityResolver
		ctx      context.Context
	}{
		{
			name:     "typed nil receiver",
			resolver: nil,
			ctx:      context.Background(),
		},
		{
			name: "nil context actor",
			resolver: &FuncIdentityResolver{
				ActorFunc: func(_ context.Context) string { return "user-42" },
			},
			ctx: nil,
		},
		{
			name: "actor missing without fallback",
			resolver: &FuncIdentityResolver{
				ActorFunc: nil,
			},
			ctx: context.Background(),
		},
		{
			name: "actor empty without fallback",
			resolver: &FuncIdentityResolver{
				ActorFunc: func(_ context.Context) string { return "" },
			},
			ctx: context.Background(),
		},
		{
			name: "whitespace-only default actor",
			resolver: &FuncIdentityResolver{
				ActorFunc:    nil,
				DefaultActor: " ",
			},
			ctx: context.Background(),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			_, err := tt.resolver.Actor(tt.ctx)
			require.ErrorIs(t, err, domain.ErrPermissionDenied)
		})
	}
}

// Fail-closed matrix for TenantID: mirrors the Actor matrix, plus the
// max-length overflow case (TenantID has no fallback mechanism).
func TestFuncIdentityResolver_TenantID_FailsClosedWithoutTenantIdentity(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		resolver *FuncIdentityResolver
		ctx      context.Context
	}{
		{
			name:     "typed nil receiver",
			resolver: nil,
			ctx:      context.Background(),
		},
		{
			name: "nil tenant func",
			resolver: &FuncIdentityResolver{
				TenantFunc: nil,
			},
			ctx: context.Background(),
		},
		{
			name: "nil context",
			resolver: &FuncIdentityResolver{
				TenantFunc: func(_ context.Context) string { return "tenant-1" },
			},
			ctx: nil,
		},
		{
			name: "empty tenant",
			resolver: &FuncIdentityResolver{
				TenantFunc: func(_ context.Context) string { return "" },
			},
			ctx: context.Background(),
		},
		{
			name: "whitespace tenant",
			resolver: &FuncIdentityResolver{
				TenantFunc: func(_ context.Context) string { return " " },
			},
			ctx: context.Background(),
		},
		{
			name: "tenant exceeds max length",
			resolver: &FuncIdentityResolver{
				TenantFunc: func(_ context.Context) string { return strings.Repeat("t", maxTenantIDLength+1) },
			},
			ctx: context.Background(),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			_, err := tt.resolver.TenantID(tt.ctx)
			require.ErrorIs(t, err, domain.ErrPermissionDenied)
		})
	}
}

// Boundary check: a tenant ID of exactly maxTenantIDLength is accepted.
func TestFuncIdentityResolver_TenantID_ExactMaxLength_Allowed(t *testing.T) {
	t.Parallel()

	exactID := strings.Repeat("t", maxTenantIDLength)

	resolver := &FuncIdentityResolver{
		TenantFunc: func(_ context.Context) string {
			return exactID
		},
	}

	tenant, err := resolver.TenantID(context.Background())
	require.NoError(t, err)
	assert.Equal(t, exactID, tenant)
}

// Copyright 2025 Lerian Studio.

package service

import (
	"reflect"
	"sort"

	"github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
)

// ComponentDiff computes which infrastructure components need rebuilding
// by comparing effective values between two snapshots.
type ComponentDiff struct {
	// allComponents is the sorted set of every concrete component name;
	// it doubles as the "rebuild everything" result.
	allComponents []string
	// keyMeta maps config/setting key -> classification used when that
	// key's effective value changes.
	keyMeta map[string]componentKeyMeta
}

type componentKeyMeta struct {
	component             string
	triggersBundleRebuild bool
	invalidApplyBehavior  bool
	unclassified          bool
}

// NewComponentDiff builds a diff engine from key definitions.
// Keys with rebuild-triggering apply behaviors and empty Component are tracked
// as unclassified so ChangedComponents can force a full rebuild for safety.
// Invalid apply behaviors are also treated conservatively when a changed key is
// encountered.
func NewComponentDiff(defs []domain.KeyDef) *ComponentDiff {
	keyMeta := make(map[string]componentKeyMeta, len(defs))
	componentSet := make(map[string]struct{}, len(defs))

	for _, def := range defs {
		meta := componentKeyMeta{
			component:             def.Component,
			triggersBundleRebuild: triggersBundleRebuild(def.ApplyBehavior),
			invalidApplyBehavior:  !def.ApplyBehavior.IsValid(),
			// unclassified = "would need a rebuild but we don't know which
			// component" — resolved later by rebuilding everything.
			unclassified: triggersBundleRebuild(def.ApplyBehavior) && def.Component == "",
		}
		keyMeta[def.Key] = meta

		// ComponentNone and empty-string components never become concrete
		// rebuild targets.
		if def.Component != "" && def.Component != domain.ComponentNone {
			componentSet[def.Component] = struct{}{}
		}
	}

	components := make([]string, 0, len(componentSet))
	for component := range componentSet {
		components = append(components, component)
	}

	// Sorted so AllComponents is deterministic across runs.
	sort.Strings(components)

	return &ComponentDiff{
		allComponents: components,
		keyMeta:       keyMeta,
	}
}

// triggersBundleRebuild reports whether the given apply behavior requires a
// component rebuild. Only bundle-rebuild and bundle-rebuild+worker-reconcile
// qualify; all other behaviors propagate without infrastructure teardown.
func triggersBundleRebuild(ab domain.ApplyBehavior) bool {
	switch ab {
	case domain.ApplyBundleRebuild, domain.ApplyBundleRebuildAndReconcile:
		return true
	default:
		return false
	}
}

// ChangedComponents returns the set of component names that have at least
// one key whose effective value differs between prev and current snapshots.
// Configs, global settings, and tenant settings are all considered.
// If prev is zero-value, ALL components are returned (full rebuild).
// Unknown keys, invalid apply behaviors, or unclassified rebuild-triggering
// keys also force a full rebuild for safety.
func (d *ComponentDiff) ChangedComponents(prev, current domain.Snapshot) map[string]bool {
	// Nil receiver or no known components: nothing can be rebuilt.
	if d == nil || len(d.allComponents) == 0 {
		return map[string]bool{}
	}

	// First-ever snapshot: no baseline to diff, rebuild everything.
	if snapshotIsZeroValue(prev) {
		return d.allComponentSet()
	}

	changed := make(map[string]bool)
	// Each mark pass returns true when it hits an unsafe condition
	// (unknown/invalid/unclassified key), which escalates to full rebuild.
	if d.markChangedComponents(prev.Configs, current.Configs, changed) {
		return d.allComponentSet()
	}

	if d.markChangedComponents(prev.GlobalSettings, current.GlobalSettings, changed) {
		return d.allComponentSet()
	}

	if d.markChangedTenantSettings(prev.TenantSettings, current.TenantSettings, changed) {
		return d.allComponentSet()
	}

	return changed
}

// AllComponents returns every distinct component name in the mapping, sorted.
func (d *ComponentDiff) AllComponents() []string {
	if d == nil || len(d.allComponents) == 0 {
		return []string{}
	}

	// Defensive copy so callers cannot mutate the internal slice.
	return append([]string(nil), d.allComponents...)
}

// snapshotIsZeroValue reports whether the snapshot carries no data at all
// (revision zero, no build time, empty maps) — i.e. there is no baseline.
func snapshotIsZeroValue(snapshot domain.Snapshot) bool {
	return snapshot.Revision == domain.RevisionZero &&
		snapshot.BuiltAt.IsZero() &&
		len(snapshot.Configs) == 0 &&
		len(snapshot.GlobalSettings) == 0 &&
		len(snapshot.TenantSettings) == 0
}

// markChangedTenantSettings diffs per-tenant setting maps across the union of
// tenant IDs from both snapshots. Returns true to request a full rebuild.
func (d *ComponentDiff) markChangedTenantSettings(prev, current map[string]map[string]domain.EffectiveValue, changed map[string]bool) bool {
	tenantIDs := make(map[string]struct{}, len(prev)+len(current))
	for tenantID := range prev {
		tenantIDs[tenantID] = struct{}{}
	}

	for tenantID := range current {
		tenantIDs[tenantID] = struct{}{}
	}

	for tenantID := range tenantIDs {
		// Missing tenants yield nil inner maps, which diff cleanly
		// (every key in the other side counts as added/removed).
		if d.markChangedComponents(prev[tenantID], current[tenantID], changed) {
			return true
		}
	}

	return false
}

// markChangedComponents records into `changed` the components affected by
// value differences between the two entry maps. A true return means an
// unsafe condition was seen and the caller must fall back to a full rebuild.
// The decision ladder for each changed key is order-dependent:
// unknown key -> full; invalid behavior -> full; non-rebuild behavior -> skip;
// unclassified rebuild key -> full; ComponentNone -> skip; else mark component.
func (d *ComponentDiff) markChangedComponents(prevEntries, currentEntries map[string]domain.EffectiveValue, changed map[string]bool) bool {
	// Union of keys so additions and removals are both detected.
	keys := make(map[string]struct{}, len(prevEntries)+len(currentEntries))
	for key := range prevEntries {
		keys[key] = struct{}{}
	}

	for key := range currentEntries {
		keys[key] = struct{}{}
	}

	for key := range keys {
		prevVal, prevOK := prevEntries[key]

		curVal, curOK := currentEntries[key]
		if !effectiveValueChanged(prevVal, prevOK, curVal, curOK) {
			continue
		}

		meta, found := d.keyMeta[key]
		if !found {
			return true
		}

		if meta.invalidApplyBehavior {
			return true
		}

		if !meta.triggersBundleRebuild {
			continue
		}

		if meta.unclassified {
			return true
		}

		if meta.component == domain.ComponentNone {
			continue
		}

		changed[meta.component] = true
	}

	return false
}

// allComponentSet returns a bool map of every component — used for full
// rebuild when there is no previous snapshot to diff against.
func (d *ComponentDiff) allComponentSet() map[string]bool {
	set := make(map[string]bool, len(d.allComponents))

	for _, component := range d.allComponents {
		set[component] = true
	}

	return set
}

// effectiveValueChanged reports whether the effective value changed between two
// snapshots. It uses reflect.DeepEqual on EffectiveValue.Value, which means:
//   - a nil slice and an empty slice compare as DIFFERENT, even though both
//     represent "no items" — such a transition is reported as a change
//   - function values are always unequal
//   - supported types: primitives, string slices, maps — the types produced by
//     the coercion layer. If other types are stored, callers must normalize
//     before comparison.
func effectiveValueChanged(prevVal domain.EffectiveValue, prevOK bool, curVal domain.EffectiveValue, curOK bool) bool {
	// Presence flip (added or removed) is always a change.
	if prevOK != curOK {
		return true
	}

	// Absent on both sides: nothing to compare.
	if !prevOK {
		return false
	}

	return !reflect.DeepEqual(prevVal.Value, curVal.Value)
}

//go:build unit

// Copyright 2025 Lerian Studio.
package service

import (
	"testing"

	"github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"
	"github.com/stretchr/testify/assert"
)

// TestNewComponentDiff pins which KeyDef shapes do (and do not) contribute
// concrete component names to the diff engine.
func TestNewComponentDiff(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name           string
		defs           []domain.KeyDef
		wantComponents []string
	}{
		{
			name:           "empty defs produce no components",
			defs:           nil,
			wantComponents: []string{},
		},
		{
			name: "all ComponentNone keys are excluded",
			defs: []domain.KeyDef{
				{Key: "k1", Component: domain.ComponentNone, ApplyBehavior: domain.ApplyBundleRebuild},
				{Key: "k2", Component: domain.ComponentNone, ApplyBehavior: domain.ApplyBundleRebuildAndReconcile},
			},
			wantComponents: []string{},
		},
		{
			name: "empty component string does not become a concrete component",
			defs: []domain.KeyDef{
				{Key: "k1", Component: "", ApplyBehavior: domain.ApplyBundleRebuild},
			},
			wantComponents: []string{},
		},
		{
			name: "bootstrap-only keys still contribute known components",
			defs: []domain.KeyDef{
				{Key: "k1", Component: "postgres", ApplyBehavior: domain.ApplyBootstrapOnly},
			},
			wantComponents: []string{"postgres"},
		},
		{
			name: "live-read keys still contribute known components",
			defs: []domain.KeyDef{
				{Key: "k1", Component: "redis", ApplyBehavior: domain.ApplyLiveRead},
			},
			wantComponents: []string{"redis"},
		},
		{
			name: "worker-reconcile keys still contribute known components",
			defs: []domain.KeyDef{
				{Key: "k1", Component: "rabbitmq", ApplyBehavior: domain.ApplyWorkerReconcile},
			},
			wantComponents: []string{"rabbitmq"},
		},
		{
			name: "bundle-rebuild keys are included",
			defs: []domain.KeyDef{
				{Key: "k1", Component: "postgres", ApplyBehavior: domain.ApplyBundleRebuild},
			},
			wantComponents: []string{"postgres"},
		},
		{
			name: "bundle-rebuild-and-reconcile keys are included",
			defs: []domain.KeyDef{
				{Key: "k1", Component: "redis", ApplyBehavior: domain.ApplyBundleRebuildAndReconcile},
			},
			wantComponents: []string{"redis"},
		},
		{
			name: "multiple components sorted",
			defs: []domain.KeyDef{
				{Key: "k1", Component: "redis", ApplyBehavior: domain.ApplyBundleRebuild},
				{Key: "k2", Component: "postgres", ApplyBehavior: domain.ApplyBundleRebuild},
				{Key: "k3", Component: "s3", ApplyBehavior: domain.ApplyBundleRebuildAndReconcile},
			},
			wantComponents: []string{"postgres", "redis", "s3"},
		},
		{
			name: "duplicate component from multiple keys is deduplicated",
			defs: []domain.KeyDef{
				{Key: "pg.host", Component: "postgres", ApplyBehavior: domain.ApplyBundleRebuild},
				{Key: "pg.port", Component: "postgres", ApplyBehavior: domain.ApplyBundleRebuild},
			},
			wantComponents: []string{"postgres"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			diff := NewComponentDiff(tt.defs)
			assert.Equal(t, tt.wantComponents, diff.AllComponents())
		})
	}
}

// TestComponentDiff_ChangedComponents pins the full decision ladder:
// per-component change marking, the zero-prev full rebuild, and every
// conservative full-rebuild escalation (unknown key, unclassified key,
// invalid apply behavior).
func TestComponentDiff_ChangedComponents(t *testing.T) {
	t.Parallel()

	// Shared defs used by most sub-tests: two postgres keys + one redis key.
	baseDefs := []domain.KeyDef{
		{Key: "pg.host", Component: "postgres", ApplyBehavior: domain.ApplyBundleRebuild},
		{Key: "pg.port", Component: "postgres", ApplyBehavior: domain.ApplyBundleRebuild},
		{Key: "redis.url", Component: "redis", ApplyBehavior: domain.ApplyBundleRebuild},
	}

	tests := []struct {
		name    string
		defs    []domain.KeyDef
		prev    domain.Snapshot
		current domain.Snapshot
		want    map[string]bool
	}{
		{
			name: "no defs means no changed components",
			defs: nil,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: "localhost"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: "remote"},
			}},
			want: map[string]bool{},
		},
		{
			name: "nil prev configs triggers full rebuild",
			defs: baseDefs,
			prev: domain.Snapshot{},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "localhost"},
				"redis.url": {Value: "redis://host"},
			}},
			want: map[string]bool{"postgres": true, "redis": true},
		},
		{
			name: "empty prev configs triggers full rebuild",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: "localhost"},
			}},
			want: map[string]bool{"postgres": true, "redis": true},
		},
		{
			name: "initial boot includes bootstrap-only components",
			defs: append(baseDefs, domain.KeyDef{Key: "server.address", Component: "http", ApplyBehavior: domain.ApplyBootstrapOnly}),
			prev: domain.Snapshot{},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":        {Value: "localhost"},
				"redis.url":      {Value: "redis://host"},
				"server.address": {Value: ":8080"},
			}},
			want: map[string]bool{"http": true, "postgres": true, "redis": true},
		},
		{
			name: "single component with one changed key",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "old-host"},
				"pg.port":   {Value: 5432},
				"redis.url": {Value: "redis://same"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "new-host"},
				"pg.port":   {Value: 5432},
				"redis.url": {Value: "redis://same"},
			}},
			want: map[string]bool{"postgres": true},
		},
		{
			name: "no changes returns empty set",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "same"},
				"pg.port":   {Value: 5432},
				"redis.url": {Value: "redis://same"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "same"},
				"pg.port":   {Value: 5432},
				"redis.url": {Value: "redis://same"},
			}},
			want: map[string]bool{},
		},
		{
			name: "multiple components changed in one diff",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "old-host"},
				"pg.port":   {Value: 5432},
				"redis.url": {Value: "redis://old"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "new-host"},
				"pg.port":   {Value: 5432},
				"redis.url": {Value: "redis://new"},
			}},
			want: map[string]bool{"postgres": true, "redis": true},
		},
		{
			name: "global setting change triggers component rebuild",
			defs: append(baseDefs, domain.KeyDef{Key: "fees.fail_closed_default", Component: "fees", ApplyBehavior: domain.ApplyBundleRebuild, Kind: domain.KindSetting}),
			prev: domain.Snapshot{GlobalSettings: map[string]domain.EffectiveValue{
				"fees.fail_closed_default": {Value: false},
			}},
			current: domain.Snapshot{GlobalSettings: map[string]domain.EffectiveValue{
				"fees.fail_closed_default": {Value: true},
			}},
			want: map[string]bool{"fees": true},
		},
		{
			name: "tenant setting change triggers component rebuild",
			defs: append(baseDefs, domain.KeyDef{Key: "fees.max_fee_amount_cents", Component: "fees", ApplyBehavior: domain.ApplyBundleRebuild, Kind: domain.KindSetting}),
			prev: domain.Snapshot{TenantSettings: map[string]map[string]domain.EffectiveValue{
				"tenant-a": {
					"fees.max_fee_amount_cents": {Value: 100},
				},
			}},
			current: domain.Snapshot{TenantSettings: map[string]map[string]domain.EffectiveValue{
				"tenant-a": {
					"fees.max_fee_amount_cents": {Value: 200},
				},
			}},
			want: map[string]bool{"fees": true},
		},
		{
			name: "key exists in current but not prev marks component changed",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: "host"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "host"},
				"redis.url": {Value: "redis://new"},
			}},
			want: map[string]bool{"redis": true},
		},
		{
			name: "key exists in prev but not current marks component changed",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "host"},
				"redis.url": {Value: "redis://old"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: "host"},
			}},
			want: map[string]bool{"redis": true},
		},
		{
			name: "both keys absent for a component means no change",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: "host"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: "host"},
			}},
			want: map[string]bool{},
		},
		{
			name: "deep equality catches structural changes",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: []string{"host-a", "host-b"}},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host": {Value: []string{"host-a", "host-c"}},
			}},
			want: map[string]bool{"postgres": true},
		},
		{
			name: "unknown changed key forces full rebuild for safety",
			defs: baseDefs,
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "same"},
				"unknown.k": {Value: "old"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":   {Value: "same"},
				"unknown.k": {Value: "new"},
			}},
			want: map[string]bool{"postgres": true, "redis": true},
		},
		{
			name: "known live-read change does not rebuild components",
			defs: append(baseDefs, domain.KeyDef{Key: "app.log_level", Component: domain.ComponentNone, ApplyBehavior: domain.ApplyLiveRead}),
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":       {Value: "same"},
				"app.log_level": {Value: "info"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":       {Value: "same"},
				"app.log_level": {Value: "debug"},
			}},
			want: map[string]bool{},
		},
		{
			name: "known bootstrap-only change does not rebuild components after boot",
			defs: append(baseDefs, domain.KeyDef{Key: "server.address", Component: "http", ApplyBehavior: domain.ApplyBootstrapOnly}),
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":        {Value: "same"},
				"server.address": {Value: ":8080"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":        {Value: "same"},
				"server.address": {Value: ":8081"},
			}},
			want: map[string]bool{},
		},
		{
			name: "known worker-reconcile change does not rebuild components",
			defs: append(baseDefs, domain.KeyDef{Key: "worker.interval", Component: "worker", ApplyBehavior: domain.ApplyWorkerReconcile}),
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":         {Value: "same"},
				"worker.interval": {Value: 1},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":         {Value: "same"},
				"worker.interval": {Value: 2},
			}},
			want: map[string]bool{},
		},
		{
			name: "rebuild key with ComponentNone stays excluded",
			defs: append(baseDefs, domain.KeyDef{Key: "feature.toggle", Component: domain.ComponentNone, ApplyBehavior: domain.ApplyBundleRebuild}),
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"feature.toggle": {Value: false},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"feature.toggle": {Value: true},
			}},
			want: map[string]bool{},
		},
		{
			name: "unclassified rebuild key forces full rebuild",
			defs: append(baseDefs, domain.KeyDef{Key: "pg.unclassified", Component: "", ApplyBehavior: domain.ApplyBundleRebuild}),
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":         {Value: "same"},
				"pg.unclassified": {Value: "old"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":         {Value: "same"},
				"pg.unclassified": {Value: "new"},
			}},
			want: map[string]bool{"postgres": true, "redis": true},
		},
		{
			name: "invalid apply behavior forces full rebuild for safety",
			defs: append(baseDefs, domain.KeyDef{Key: "pg.invalid", Component: "postgres", ApplyBehavior: domain.ApplyBehavior("invalid")}),
			prev: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":    {Value: "same"},
				"pg.invalid": {Value: "old"},
			}},
			current: domain.Snapshot{Configs: map[string]domain.EffectiveValue{
				"pg.host":    {Value: "same"},
				"pg.invalid": {Value: "new"},
			}},
			want: map[string]bool{"postgres": true, "redis": true},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			diff := NewComponentDiff(tt.defs)
			got := diff.ChangedComponents(tt.prev, tt.current)
			assert.Equal(t, tt.want, got)
		})
	}
}

// A typed-nil *ComponentDiff must degrade gracefully, never panic.
func TestComponentDiff_NilReceiver(t *testing.T) {
	t.Parallel()

	var diff *ComponentDiff

	assert.Equal(t, map[string]bool{}, diff.ChangedComponents(domain.Snapshot{}, domain.Snapshot{}))
	assert.Equal(t, []string{}, diff.AllComponents())
}

func TestComponentDiff_AllComponents_Sorted(t *testing.T) {
	t.Parallel()

	defs := []domain.KeyDef{
		{Key: "z.key", Component:
"zebra", ApplyBehavior: domain.ApplyBundleRebuild}, + {Key: "a.key", Component: "alpha", ApplyBehavior: domain.ApplyBundleRebuild}, + {Key: "m.key", Component: "middle", ApplyBehavior: domain.ApplyBundleRebuildAndReconcile}, + } + + diff := NewComponentDiff(defs) + + assert.Equal(t, []string{"alpha", "middle", "zebra"}, diff.AllComponents()) +} diff --git a/commons/systemplane/service/manager.go b/commons/systemplane/service/manager.go index 7e4e09ec..f4fbab94 100644 --- a/commons/systemplane/service/manager.go +++ b/commons/systemplane/service/manager.go @@ -47,6 +47,7 @@ type ResolvedSet struct { // SchemaEntry describes a single key's metadata for the schema endpoint. type SchemaEntry struct { Key string + EnvVar string Kind domain.Kind AllowedScopes []domain.Scope ValueType domain.ValueType @@ -85,28 +86,36 @@ type ManagerConfig struct { StateSync func(ctx context.Context, snapshot domain.Snapshot) } -// NewManager creates a new Manager with the supplied dependencies. All -// dependencies are required; a nil dependency causes a construction-time -// error rather than a runtime panic on first use. -func NewManager(cfg ManagerConfig) (Manager, error) { +func validateManagerConfig(cfg ManagerConfig) error { if domain.IsNilValue(cfg.Registry) { - return nil, errManagerRegistryRequired + return errManagerRegistryRequired } if domain.IsNilValue(cfg.Store) { - return nil, errManagerStoreRequired + return errManagerStoreRequired } if domain.IsNilValue(cfg.History) { - return nil, errManagerHistoryRequired + return errManagerHistoryRequired } if domain.IsNilValue(cfg.Supervisor) { - return nil, errManagerSupervisorRequired + return errManagerSupervisorRequired } if cfg.Builder == nil { - return nil, errManagerBuilderRequired + return errManagerBuilderRequired + } + + return nil +} + +// NewManager creates a new Manager with the supplied dependencies. 
All +// dependencies are required; a nil dependency causes a construction-time +// error rather than a runtime panic on first use. +func NewManager(cfg ManagerConfig) (Manager, error) { + if err := validateManagerConfig(cfg); err != nil { + return nil, err } return &defaultManager{ diff --git a/commons/systemplane/service/manager_helpers.go b/commons/systemplane/service/manager_helpers.go index b5f5486d..7b5089e3 100644 --- a/commons/systemplane/service/manager_helpers.go +++ b/commons/systemplane/service/manager_helpers.go @@ -25,6 +25,7 @@ func buildSchema(reg registry.Registry, kind domain.Kind) []SchemaEntry { for i, def := range defs { entries[i] = SchemaEntry{ Key: def.Key, + EnvVar: def.EnvVar, Kind: def.Kind, AllowedScopes: append([]domain.Scope(nil), def.AllowedScopes...), ValueType: def.ValueType, @@ -32,7 +33,7 @@ func buildSchema(reg registry.Registry, kind domain.Kind) []SchemaEntry { MutableAtRuntime: def.MutableAtRuntime, ApplyBehavior: def.ApplyBehavior, Secret: def.Secret, - RedactPolicy: def.RedactPolicy, + RedactPolicy: effectiveRedactPolicy(def), Description: def.Description, Group: def.Group, } @@ -45,12 +46,15 @@ func cloneSnapshot(snapshot domain.Snapshot) domain.Snapshot { cloned := domain.Snapshot{ Configs: cloneEffectiveValues(snapshot.Configs), GlobalSettings: cloneEffectiveValues(snapshot.GlobalSettings), - TenantSettings: make(map[string]map[string]domain.EffectiveValue, len(snapshot.TenantSettings)), Revision: snapshot.Revision, BuiltAt: snapshot.BuiltAt, } - for tenantID, values := range snapshot.TenantSettings { - cloned.TenantSettings[tenantID] = cloneEffectiveValues(values) + + if snapshot.TenantSettings != nil { + cloned.TenantSettings = make(map[string]map[string]domain.EffectiveValue, len(snapshot.TenantSettings)) + for tenantID, values := range snapshot.TenantSettings { + cloned.TenantSettings[tenantID] = cloneEffectiveValues(values) + } } return cloned @@ -195,25 +199,34 @@ func redactValue(def domain.KeyDef, value any) any { 
return nil } - // Secret keys are always redacted regardless of RedactPolicy setting. - // This prevents accidental secret leaks when a developer sets Secret=true - // but forgets to set RedactPolicy explicitly. - if def.Secret { - return maskedValuePlaceholder - } - - if def.RedactPolicy == "" || def.RedactPolicy == domain.RedactNone { + policy := effectiveRedactPolicy(def) + if policy == domain.RedactNone { return value } - if def.RedactPolicy == domain.RedactMask { + if policy == domain.RedactMask { return maskRedactedValue(value) } - // Non-secret keys with an explicit full redact policy are fully hidden. return maskedValuePlaceholder } +func effectiveRedactPolicy(def domain.KeyDef) domain.RedactPolicy { + if def.Secret { + return domain.RedactFull + } + + if def.RedactPolicy == "" { + return domain.RedactNone + } + + return def.RedactPolicy +} + +func isRedacted(def domain.KeyDef) bool { + return effectiveRedactPolicy(def) != domain.RedactNone +} + func maskRedactedValue(value any) any { stringValue, ok := value.(string) if !ok { diff --git a/commons/systemplane/service/manager_helpers_test.go b/commons/systemplane/service/manager_helpers_test.go index 6e6390bb..17054f9c 100644 --- a/commons/systemplane/service/manager_helpers_test.go +++ b/commons/systemplane/service/manager_helpers_test.go @@ -52,12 +52,18 @@ func TestCloneSnapshot_DeepClonesNestedRuntimeValues(t *testing.T) { defaultMap["enabled"] = true overrideMap["kind"] = "automatic" - originalValueMap := snapshot.Configs["feature.flags"].Value.(map[string]any) - originalNestedSlice := originalValueMap["nested"].([]any) - originalNestedMap := originalNestedSlice[1].(map[string]any) - originalDefaultMap := snapshot.Configs["feature.flags"].Default.(map[string]any) - originalOverrideSlice := snapshot.Configs["feature.flags"].Override.([]any) - originalOverrideMap := originalOverrideSlice[0].(map[string]any) + originalValueMap, ok := snapshot.Configs["feature.flags"].Value.(map[string]any) + 
require.True(t, ok, "original Value must be map[string]any") + originalNestedSlice, ok := originalValueMap["nested"].([]any) + require.True(t, ok, "original nested must be []any") + originalNestedMap, ok := originalNestedSlice[1].(map[string]any) + require.True(t, ok, "original nested[1] must be map[string]any") + originalDefaultMap, ok := snapshot.Configs["feature.flags"].Default.(map[string]any) + require.True(t, ok, "original Default must be map[string]any") + originalOverrideSlice, ok := snapshot.Configs["feature.flags"].Override.([]any) + require.True(t, ok, "original Override must be []any") + originalOverrideMap, ok := originalOverrideSlice[0].(map[string]any) + require.True(t, ok, "original Override[0] must be map[string]any") assert.Equal(t, true, originalNestedMap["beta"]) assert.Equal(t, false, originalDefaultMap["enabled"]) @@ -79,3 +85,18 @@ func TestRedactValue_RedactMaskFallsBackForNonString(t *testing.T) { assert.Equal(t, "****", masked) } + +func TestRedactValue_SecretDefaultsToFullRedaction(t *testing.T) { + t.Parallel() + + masked := redactValue(domain.KeyDef{Secret: true}, "amqp://user:pass@host") + + assert.Equal(t, "****", masked) +} + +func TestEffectiveRedactPolicy_NormalizesZeroValue(t *testing.T) { + t.Parallel() + + assert.Equal(t, domain.RedactNone, effectiveRedactPolicy(domain.KeyDef{})) + assert.Equal(t, domain.RedactFull, effectiveRedactPolicy(domain.KeyDef{Secret: true})) +} diff --git a/commons/systemplane/service/manager_read_helpers.go b/commons/systemplane/service/manager_read_helpers.go new file mode 100644 index 00000000..f9b8d558 --- /dev/null +++ b/commons/systemplane/service/manager_read_helpers.go @@ -0,0 +1,30 @@ +// Copyright 2025 Lerian Studio. 
package service

import "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain"

// buildResolvedSet wraps a value map in a ResolvedSet, deriving the revision
// from the values themselves via revisionFromValues.
func buildResolvedSet(values map[string]domain.EffectiveValue) ResolvedSet {
	return ResolvedSet{
		Values:   values,
		Revision: revisionFromValues(values),
	}
}

// resolvedConfigsFromSnapshot returns the redacted config set cached in the
// snapshot. The second return is false when the snapshot is unusable (never
// built, or has no configs), signalling the caller to fall back to a fresh
// build. The snapshot's configs are cloned before redaction so the cached
// snapshot is never mutated.
func (manager *defaultManager) resolvedConfigsFromSnapshot(snapshot domain.Snapshot) (ResolvedSet, bool) {
	if snapshot.BuiltAt.IsZero() || snapshot.Configs == nil {
		return ResolvedSet{}, false
	}

	values := redactEffectiveValues(manager.registry, cloneEffectiveValues(snapshot.Configs))

	return buildResolvedSet(values), true
}

// resolvedSettingsFromSnapshot returns the cached settings for subject from
// the snapshot. The second return is false when the snapshot was never built
// or when cachedSettingsFromSnapshot cannot serve the subject, signalling the
// caller to fall back to a fresh build.
func (manager *defaultManager) resolvedSettingsFromSnapshot(snapshot domain.Snapshot, subject Subject) (ResolvedSet, bool) {
	if snapshot.BuiltAt.IsZero() {
		return ResolvedSet{}, false
	}

	return manager.cachedSettingsFromSnapshot(snapshot, subject)
}
configs: %w", err) + } return ResolvedSet{ - Values: values, - Revision: revisionFromValues(values), + Values: redactEffectiveValues(manager.registry, values), + Revision: revision, }, nil } @@ -50,12 +45,8 @@ func (manager *defaultManager) GetSettings(ctx context.Context, subject Subject) defer span.End() snap := manager.supervisor.Snapshot() - - if !snap.BuiltAt.IsZero() { - resolved, ok := manager.cachedSettingsFromSnapshot(snap, subject) - if ok { - return resolved, nil - } + if resolved, ok := manager.resolvedSettingsFromSnapshot(snap, subject); ok { + return resolved, nil } values, revision, err := manager.builder.BuildSettings(ctx, subject) diff --git a/commons/systemplane/service/manager_reads_test.go b/commons/systemplane/service/manager_reads_test.go index 946d5f5c..7004cf48 100644 --- a/commons/systemplane/service/manager_reads_test.go +++ b/commons/systemplane/service/manager_reads_test.go @@ -361,80 +361,6 @@ func TestCachedSettingsFromSnapshot_UnknownScope_ReturnsFalse(t *testing.T) { assert.False(t, ok) } -// --------------------------------------------------------------------------- -// GetConfigSchema / GetSettingSchema -// --------------------------------------------------------------------------- - -func TestGetConfigSchema_ReturnsConfigKeysOnly(t *testing.T) { - t.Parallel() - - reg, store, history, spy, builder := testManagerDeps(t) - registerTestConfigKey(t, reg, "app.workers", domain.ApplyLiveRead, true) - registerTestSettingKey(t, reg, "ui.theme", []domain.Scope{domain.ScopeGlobal}, true) - - mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) - require.NoError(t, mgrErr) - - entries, err := mgr.GetConfigSchema(context.Background()) - require.NoError(t, err) - require.Len(t, entries, 1) - assert.Equal(t, "app.workers", entries[0].Key) - assert.Equal(t, domain.KindConfig, entries[0].Kind) -} - -func TestGetSettingSchema_ReturnsSettingKeysOnly(t *testing.T) { - t.Parallel() 
- - reg, store, history, spy, builder := testManagerDeps(t) - registerTestConfigKey(t, reg, "app.workers", domain.ApplyLiveRead, true) - registerTestSettingKey(t, reg, "ui.theme", []domain.Scope{domain.ScopeGlobal, domain.ScopeTenant}, true) - - mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) - require.NoError(t, mgrErr) - - entries, err := mgr.GetSettingSchema(context.Background()) - require.NoError(t, err) - require.Len(t, entries, 1) - assert.Equal(t, "ui.theme", entries[0].Key) - assert.Equal(t, domain.KindSetting, entries[0].Kind) -} - -func TestGetConfigSchema_RedactsSecretDefault(t *testing.T) { - t.Parallel() - - reg, store, history, spy, builder := testManagerDeps(t) - require.NoError(t, reg.Register(domain.KeyDef{ - Key: "auth.token", - Kind: domain.KindConfig, - AllowedScopes: []domain.Scope{domain.ScopeGlobal}, - ValueType: domain.ValueTypeString, - DefaultValue: "my-token-value", - Secret: true, - RedactPolicy: domain.RedactFull, - ApplyBehavior: domain.ApplyLiveRead, - MutableAtRuntime: true, - Description: "auth token", - Group: "auth", - })) - - mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) - require.NoError(t, mgrErr) - - entries, err := mgr.GetConfigSchema(context.Background()) - require.NoError(t, err) - - var found bool - - for _, entry := range entries { - if entry.Key == "auth.token" { - found = true - assert.Equal(t, "****", entry.DefaultValue) - } - } - - assert.True(t, found, "expected auth.token in schema") -} - // --------------------------------------------------------------------------- // GetConfigHistory / GetSettingHistory // --------------------------------------------------------------------------- diff --git a/commons/systemplane/service/manager_schema_test.go b/commons/systemplane/service/manager_schema_test.go new file mode 100644 index 00000000..4da66740 --- /dev/null +++ 
b/commons/systemplane/service/manager_schema_test.go @@ -0,0 +1,113 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. + +package service + +import ( + "context" + "testing" + + "github.com/LerianStudio/lib-commons/v4/commons/systemplane/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigSchema_ReturnsConfigKeysOnly(t *testing.T) { + t.Parallel() + + reg, store, history, spy, builder := testManagerDeps(t) + registerTestConfigKey(t, reg, "app.workers", domain.ApplyLiveRead, true) + registerTestSettingKey(t, reg, "ui.theme", []domain.Scope{domain.ScopeGlobal}, true) + + mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) + require.NoError(t, mgrErr) + + entries, err := mgr.GetConfigSchema(context.Background()) + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, "app.workers", entries[0].Key) + assert.Equal(t, domain.KindConfig, entries[0].Kind) +} + +func TestGetSettingSchema_ReturnsSettingKeysOnly(t *testing.T) { + t.Parallel() + + reg, store, history, spy, builder := testManagerDeps(t) + registerTestConfigKey(t, reg, "app.workers", domain.ApplyLiveRead, true) + registerTestSettingKey(t, reg, "ui.theme", []domain.Scope{domain.ScopeGlobal, domain.ScopeTenant}, true) + + mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) + require.NoError(t, mgrErr) + + entries, err := mgr.GetSettingSchema(context.Background()) + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, "ui.theme", entries[0].Key) + assert.Equal(t, domain.KindSetting, entries[0].Kind) +} + +func TestGetConfigSchema_RedactsSecretDefault(t *testing.T) { + t.Parallel() + + reg, store, history, spy, builder := testManagerDeps(t) + require.NoError(t, reg.Register(domain.KeyDef{ + Key: "auth.token", + EnvVar: "AUTH_TOKEN", + Kind: domain.KindConfig, + AllowedScopes: 
[]domain.Scope{domain.ScopeGlobal}, + ValueType: domain.ValueTypeString, + DefaultValue: "my-token-value", + Secret: true, + RedactPolicy: domain.RedactFull, + ApplyBehavior: domain.ApplyLiveRead, + MutableAtRuntime: true, + Description: "auth token", + Group: "auth", + })) + + mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) + require.NoError(t, mgrErr) + + entries, err := mgr.GetConfigSchema(context.Background()) + require.NoError(t, err) + + var found bool + + for _, entry := range entries { + if entry.Key == "auth.token" { + found = true + assert.Equal(t, "****", entry.DefaultValue) + assert.Equal(t, "AUTH_TOKEN", entry.EnvVar) + assert.Equal(t, domain.RedactFull, entry.RedactPolicy) + } + } + + assert.True(t, found, "expected auth.token in schema") +} + +func TestGetConfigSchema_ZeroValueRedactPolicyBecomesNone(t *testing.T) { + t.Parallel() + + reg, store, history, spy, builder := testManagerDeps(t) + require.NoError(t, reg.Register(domain.KeyDef{ + Key: "app.name", + EnvVar: "APP_NAME", + Kind: domain.KindConfig, + AllowedScopes: []domain.Scope{domain.ScopeGlobal}, + ValueType: domain.ValueTypeString, + DefaultValue: "service", + ApplyBehavior: domain.ApplyLiveRead, + MutableAtRuntime: true, + Description: "app name", + Group: "app", + })) + + mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) + require.NoError(t, mgrErr) + + entries, err := mgr.GetConfigSchema(context.Background()) + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, domain.RedactNone, entries[0].RedactPolicy) +} diff --git a/commons/systemplane/service/manager_test.go b/commons/systemplane/service/manager_test.go index fd53c851..da335d4a 100644 --- a/commons/systemplane/service/manager_test.go +++ b/commons/systemplane/service/manager_test.go @@ -365,34 +365,6 @@ func TestManager_PatchSettings_ScopeValidation(t *testing.T) { 
assert.ErrorIs(t, err, domain.ErrScopeInvalid) } -func TestManager_GetSchema_RedactsSecretDefaults(t *testing.T) { - t.Parallel() - - reg, store, history, spy, builder := testManagerDeps(t) - require.NoError(t, reg.Register(domain.KeyDef{ - Key: "auth.secret", - Kind: domain.KindConfig, - AllowedScopes: []domain.Scope{domain.ScopeGlobal}, - ValueType: domain.ValueTypeString, - DefaultValue: "super-secret", - Secret: true, - RedactPolicy: domain.RedactFull, - ApplyBehavior: domain.ApplyLiveRead, - MutableAtRuntime: true, - Description: "secret", - Group: "auth", - })) - - mgr, mgrErr := NewManager(ManagerConfig{Registry: reg, Store: store, History: history, Supervisor: spy, Builder: builder}) - require.NoError(t, mgrErr) - - entries, err := mgr.GetConfigSchema(context.Background()) - require.NoError(t, err) - require.Len(t, entries, 1) - assert.Equal(t, "****", entries[0].DefaultValue) - assert.Equal(t, domain.RedactFull, entries[0].RedactPolicy) -} - func TestManager_GetHistory_RedactsSecrets(t *testing.T) { t.Parallel() diff --git a/commons/systemplane/service/manager_writes.go b/commons/systemplane/service/manager_writes.go index 5e078a74..63d07900 100644 --- a/commons/systemplane/service/manager_writes.go +++ b/commons/systemplane/service/manager_writes.go @@ -28,10 +28,8 @@ func (manager *defaultManager) PatchConfigs(ctx context.Context, req PatchReques return WriteResult{}, nil } - for _, op := range req.Ops { - if err := manager.validateConfigOp(op); err != nil { - return WriteResult{}, err - } + if err := manager.validateConfigOps(req.Ops); err != nil { + return WriteResult{}, err } if manager.configWriteValidator != nil { @@ -47,30 +45,19 @@ func (manager *defaultManager) PatchConfigs(ctx context.Context, req PatchReques } } - escalation, _, err := Escalate(manager.registry, req.Ops) + plan, err := manager.buildWritePlan(domain.KindConfig, domain.ScopeGlobal, "", req.Ops) if err != nil { libOpentelemetry.HandleSpanError(span, "escalate config patch", err) 
return WriteResult{}, fmt.Errorf("patch configs escalation: %w", err) } - target, err := domain.NewTarget(domain.KindConfig, domain.ScopeGlobal, "") - if err != nil { - libOpentelemetry.HandleSpanError(span, "build config target", err) - return WriteResult{}, fmt.Errorf("patch configs target: %w", err) - } - - revision, err := manager.store.Put(ctx, target, req.Ops, req.ExpectedRevision, req.Actor, req.Source) + result, err := manager.persistAndApplyWrite(ctx, plan, req) if err != nil { - libOpentelemetry.HandleSpanError(span, "persist config patch", err) - return WriteResult{}, fmt.Errorf("patch configs put: %w", err) - } - - if err := manager.applyEscalation(ctx, target, escalation); err != nil { - libOpentelemetry.HandleSpanError(span, "apply config escalation", err) - return WriteResult{}, fmt.Errorf("patch configs apply: %w", err) + libOpentelemetry.HandleSpanError(span, "persist/apply config patch", err) + return result, fmt.Errorf("patch configs write: %w", err) } - return WriteResult{Revision: revision}, nil + return result, nil } // PatchSettings validates and persists setting mutations for the provided subject. 
@@ -84,36 +71,23 @@ func (manager *defaultManager) PatchSettings(ctx context.Context, subject Subjec return WriteResult{}, nil } - for _, op := range req.Ops { - if err := manager.validateSettingOp(op, subject.Scope); err != nil { - return WriteResult{}, err - } + if err := manager.validateSettingOps(req.Ops, subject.Scope); err != nil { + return WriteResult{}, err } - escalation, _, err := Escalate(manager.registry, req.Ops) + plan, err := manager.buildWritePlan(domain.KindSetting, subject.Scope, subject.SubjectID, req.Ops) if err != nil { libOpentelemetry.HandleSpanError(span, "escalate settings patch", err) return WriteResult{}, fmt.Errorf("patch settings escalation: %w", err) } - target, err := domain.NewTarget(domain.KindSetting, subject.Scope, subject.SubjectID) - if err != nil { - libOpentelemetry.HandleSpanError(span, "build settings target", err) - return WriteResult{}, fmt.Errorf("patch settings target: %w", err) - } - - revision, err := manager.store.Put(ctx, target, req.Ops, req.ExpectedRevision, req.Actor, req.Source) + result, err := manager.persistAndApplyWrite(ctx, plan, req) if err != nil { - libOpentelemetry.HandleSpanError(span, "persist settings patch", err) - return WriteResult{}, fmt.Errorf("patch settings put: %w", err) - } - - if err := manager.applyEscalation(ctx, target, escalation); err != nil { - libOpentelemetry.HandleSpanError(span, "apply settings escalation", err) - return WriteResult{}, fmt.Errorf("patch settings apply: %w", err) + libOpentelemetry.HandleSpanError(span, "persist/apply settings patch", err) + return result, fmt.Errorf("patch settings write: %w", err) } - return WriteResult{Revision: revision}, nil + return result, nil } // ApplyChangeSignal applies a precomputed runtime escalation from an external source. 
diff --git a/commons/systemplane/service/manager_writes_helpers.go b/commons/systemplane/service/manager_writes_helpers.go index 09bc1ab8..d92e8327 100644 --- a/commons/systemplane/service/manager_writes_helpers.go +++ b/commons/systemplane/service/manager_writes_helpers.go @@ -11,6 +11,11 @@ import ( "github.com/LerianStudio/lib-commons/v4/commons/systemplane/ports" ) +type writePlan struct { + target domain.Target + escalation domain.ApplyBehavior +} + func (manager *defaultManager) previewConfigSnapshot(ctx context.Context, ops []ports.WriteOp) (domain.Snapshot, error) { current := cloneSnapshot(manager.supervisor.Snapshot()) if current.BuiltAt.IsZero() || current.Configs == nil { @@ -35,7 +40,7 @@ func (manager *defaultManager) previewConfigSnapshot(ctx context.Context, ops [] ev := current.Configs[op.Key] ev.Key = def.Key ev.Default = def.DefaultValue - ev.Redacted = def.RedactPolicy != domain.RedactNone + ev.Redacted = isRedacted(def) if op.Reset || domain.IsNilValue(op.Value) { ev.Value = def.DefaultValue @@ -56,8 +61,58 @@ func (manager *defaultManager) previewConfigSnapshot(ctx context.Context, ops [] return current, nil } -// PatchSettings validates the mutations, persists them, and applies the -// escalation behavior. 
+func (manager *defaultManager) validateConfigOps(ops []ports.WriteOp) error { + for _, op := range ops { + if err := manager.validateConfigOp(op); err != nil { + return err + } + } + + return nil +} + +func (manager *defaultManager) validateSettingOps(ops []ports.WriteOp, scope domain.Scope) error { + for _, op := range ops { + if err := manager.validateSettingOp(op, scope); err != nil { + return err + } + } + + return nil +} + +func (manager *defaultManager) buildWritePlan(kind domain.Kind, scope domain.Scope, subjectID string, ops []ports.WriteOp) (writePlan, error) { + escalation, _, err := Escalate(manager.registry, ops) + if err != nil { + return writePlan{}, fmt.Errorf("escalate: %w", err) + } + + target, err := domain.NewTarget(kind, scope, subjectID) + if err != nil { + return writePlan{}, fmt.Errorf("build target: %w", err) + } + + return writePlan{target: target, escalation: escalation}, nil +} + +func (manager *defaultManager) persistAndApplyWrite( + ctx context.Context, + plan writePlan, + req PatchRequest, +) (WriteResult, error) { + revision, err := manager.store.Put(ctx, plan.target, req.Ops, req.ExpectedRevision, req.Actor, req.Source) + if err != nil { + return WriteResult{}, fmt.Errorf("persist patch: %w", err) + } + + if err := manager.applyEscalation(ctx, plan.target, plan.escalation); err != nil { + // Return the persisted revision so callers can reconcile/retry safely + // even when the post-write escalation (e.g., reload signal) fails. + return WriteResult{Revision: revision}, fmt.Errorf("apply escalation: %w", err) + } + + return WriteResult{Revision: revision}, nil +} func (manager *defaultManager) validateConfigOp(op ports.WriteOp) error { def, ok := manager.registry.Get(op.Key) @@ -157,9 +212,6 @@ func (manager *defaultManager) applyWithSnapshot( return nil } -// ApplyChangeSignal applies an externally produced change signal using the -// signal's escalation behavior or a safe rebuild fallback. 
- func (manager *defaultManager) buildActiveSnapshot(ctx context.Context, target domain.Target) (domain.Snapshot, error) { current := cloneSnapshot(manager.supervisor.Snapshot()) diff --git a/commons/systemplane/service/shutdown.go b/commons/systemplane/service/shutdown.go new file mode 100644 index 00000000..6953eac7 --- /dev/null +++ b/commons/systemplane/service/shutdown.go @@ -0,0 +1,96 @@ +// Copyright 2025 Lerian Studio. + +package service + +import ( + "context" + "errors" + "fmt" +) + +// ShutdownSequence performs the canonical 5-step systemplane shutdown. +// All steps execute regardless of errors. Nil fields are skipped. +type ShutdownSequence struct { + // PreventMutations stops accepting config writes (e.g., configManager.Stop()). + PreventMutations func() + // CancelFeed cancels the change feed subscription context. + CancelFeed context.CancelFunc + // StopSupervisor stops the supervisor and closes the active bundle. + StopSupervisor func(context.Context) error + // CloseBackend closes the systemplane store connection. + CloseBackend func() error + // StopWorkers stops background workers. + StopWorkers func() +} + +// Execute runs all shutdown steps in canonical order. +// Returns combined errors via errors.Join. All steps run even if earlier ones fail or panic. +func (s *ShutdownSequence) Execute(ctx context.Context) error { + if s == nil { + return nil + } + + if ctx == nil { + ctx = context.Background() + } + + var errs []error + + appendPanic := func(step string) { + if recovered := recover(); recovered != nil { + errs = append(errs, fmt.Errorf("%s panic: %v", step, recovered)) + } + } + appendError := func(step string, err error) { + if err != nil { + errs = append(errs, fmt.Errorf("%s: %w", step, err)) + } + } + + // Step 1: Prevent mutations (fire-and-forget). + if s.PreventMutations != nil { + func() { + defer appendPanic("prevent mutations") + + s.PreventMutations() + }() + } + + // Step 2: Cancel feed (fire-and-forget). 
+ if s.CancelFeed != nil { + func() { + defer appendPanic("cancel feed") + + s.CancelFeed() + }() + } + + // Step 3: Stop supervisor (collects error). + if s.StopSupervisor != nil { + func() { + defer appendPanic("stop supervisor") + + appendError("stop supervisor", s.StopSupervisor(ctx)) + }() + } + + // Step 4: Close backend (collects error). + if s.CloseBackend != nil { + func() { + defer appendPanic("close backend") + + appendError("close backend", s.CloseBackend()) + }() + } + + // Step 5: Stop workers (fire-and-forget). + if s.StopWorkers != nil { + func() { + defer appendPanic("stop workers") + + s.StopWorkers() + }() + } + + return errors.Join(errs...) +} diff --git a/commons/systemplane/service/shutdown_test.go b/commons/systemplane/service/shutdown_test.go new file mode 100644 index 00000000..bd57a78d --- /dev/null +++ b/commons/systemplane/service/shutdown_test.go @@ -0,0 +1,227 @@ +//go:build unit + +// Copyright 2025 Lerian Studio. + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShutdownSequence_Execute(t *testing.T) { + t.Parallel() + + errSupervisor := errors.New("supervisor failed") + errBackend := errors.New("backend failed") + + tests := []struct { + name string + allNil bool + supervisorErr error + backendErr error + wantOrder []string // expected call order recorded via append + wantErr bool + checkErr func(t *testing.T, err error) + }{ + { + name: "all steps called in canonical order", + wantOrder: []string{"prevent", "feed", "supervisor", "backend", "workers"}, + }, + { + name: "all nil steps do not panic", + allNil: true, + wantOrder: nil, + wantErr: false, + }, + { + name: "supervisor error is collected", + supervisorErr: errSupervisor, + wantOrder: []string{"prevent", "feed", "supervisor", "backend", "workers"}, + wantErr: true, + checkErr: func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, errSupervisor) + 
assert.ErrorContains(t, err, "stop supervisor") + }, + }, + { + name: "backend error is collected", + backendErr: errBackend, + wantOrder: []string{"prevent", "feed", "supervisor", "backend", "workers"}, + wantErr: true, + checkErr: func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, errBackend) + assert.ErrorContains(t, err, "close backend") + }, + }, + { + name: "both errors combined via errors.Join", + supervisorErr: errSupervisor, + backendErr: errBackend, + wantOrder: []string{"prevent", "feed", "supervisor", "backend", "workers"}, + wantErr: true, + checkErr: func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, errSupervisor) + assert.ErrorIs(t, err, errBackend) + assert.ErrorContains(t, err, "stop supervisor") + assert.ErrorContains(t, err, "close backend") + }, + }, + { + name: "all steps run even when supervisor fails", + supervisorErr: errSupervisor, + wantOrder: []string{"prevent", "feed", "supervisor", "backend", "workers"}, + wantErr: true, + checkErr: func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, errSupervisor) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var order []string + + seq := buildShutdownSequence(&order, tt.allNil, tt.supervisorErr, tt.backendErr) + + err := seq.Execute(context.Background()) + + if tt.wantOrder != nil { + assert.Equal(t, tt.wantOrder, order) + } else { + assert.Empty(t, order) + } + + if tt.wantErr { + require.Error(t, err) + + if tt.checkErr != nil { + tt.checkErr(t, err) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func buildShutdownSequence(order *[]string, allNil bool, errSup, errBack error) ShutdownSequence { + if allNil { + return ShutdownSequence{} + } + + seq := ShutdownSequence{ + PreventMutations: func() { *order = append(*order, "prevent") }, + CancelFeed: func() { *order = append(*order, "feed") }, + StopSupervisor: func(_ context.Context) error { + *order = append(*order, 
"supervisor") + return nil + }, + CloseBackend: func() error { + *order = append(*order, "backend") + return nil + }, + StopWorkers: func() { *order = append(*order, "workers") }, + } + + if errSup != nil { + seq.StopSupervisor = func(_ context.Context) error { + *order = append(*order, "supervisor") + return errSup + } + } + + if errBack != nil { + seq.CloseBackend = func() error { + *order = append(*order, "backend") + return errBack + } + } + + return seq +} + +func TestShutdownSequence_Execute_PartialNilSteps(t *testing.T) { + t.Parallel() + + var order []string + + seq := ShutdownSequence{ + PreventMutations: nil, // skipped + CancelFeed: func() { order = append(order, "feed") }, + StopSupervisor: nil, // skipped + CloseBackend: func() error { order = append(order, "backend"); return nil }, + StopWorkers: nil, // skipped + } + + err := seq.Execute(context.Background()) + + assert.NoError(t, err) + assert.Equal(t, []string{"feed", "backend"}, order) +} + +func TestShutdownSequence_Execute_NilReceiver(t *testing.T) { + t.Parallel() + + var seq *ShutdownSequence + assert.NoError(t, seq.Execute(context.Background())) +} + +func TestShutdownSequence_Execute_NilContextUsesBackground(t *testing.T) { + t.Parallel() + + called := false + seq := ShutdownSequence{ + StopSupervisor: func(ctx context.Context) error { + called = true + require.NotNil(t, ctx) + assert.NoError(t, ctx.Err()) + return nil + }, + } + + assert.NoError(t, seq.Execute(nil)) + assert.True(t, called) +} + +func TestShutdownSequence_Execute_PanicsDoNotAbortLaterSteps(t *testing.T) { + t.Parallel() + + var order []string + errBackend := errors.New("backend failed") + + seq := ShutdownSequence{ + PreventMutations: func() { + order = append(order, "prevent") + panic("boom") + }, + CancelFeed: func() { order = append(order, "feed") }, + StopSupervisor: func(_ context.Context) error { + order = append(order, "supervisor") + return nil + }, + CloseBackend: func() error { + order = append(order, "backend") + 
return errBackend + }, + StopWorkers: func() { order = append(order, "workers") }, + } + + err := seq.Execute(context.Background()) + + require.Error(t, err) + assert.Equal(t, []string{"prevent", "feed", "supervisor", "backend", "workers"}, order) + assert.ErrorContains(t, err, "prevent mutations panic") + assert.ErrorContains(t, err, "close backend") + assert.ErrorIs(t, err, errBackend) +} diff --git a/commons/systemplane/service/snapshot_builder.go b/commons/systemplane/service/snapshot_builder.go index afaa9270..ad1b2974 100644 --- a/commons/systemplane/service/snapshot_builder.go +++ b/commons/systemplane/service/snapshot_builder.go @@ -54,24 +54,14 @@ func (builder *SnapshotBuilder) BuildConfigs(ctx context.Context) (map[string]do defer span.End() defs := builder.registry.List(domain.KindConfig) - effective := initDefaults(defs) - - target, err := domain.NewTarget(domain.KindConfig, domain.ScopeGlobal, "") - if err != nil { - libOpentelemetry.HandleSpanError(span, "build config target", err) - return nil, domain.RevisionZero, fmt.Errorf("build config target: %w", err) - } - result, err := builder.store.Get(ctx, target) + effective, revision, err := builder.buildEffectiveValues(ctx, defs, domain.KindConfig, domain.ScopeGlobal, "", "global-override") if err != nil { - libOpentelemetry.HandleSpanError(span, "load config overrides", err) - return nil, domain.RevisionZero, fmt.Errorf("get config overrides: %w", err) + libOpentelemetry.HandleSpanError(span, "build configs", err) + return nil, domain.RevisionZero, fmt.Errorf("build configs: %w", err) } - applyOverrides(effective, result.Entries, "global-override") - setRevision(effective, result.Revision) - - return effective, result.Revision, nil + return effective, revision, nil } // BuildGlobalSettings builds global settings using defaults plus global @@ -83,24 +73,14 @@ func (builder *SnapshotBuilder) BuildGlobalSettings(ctx context.Context) (map[st defer span.End() defs := 
filterDefsByScope(builder.registry.List(domain.KindSetting), domain.ScopeGlobal) - effective := initDefaults(defs) - - target, err := domain.NewTarget(domain.KindSetting, domain.ScopeGlobal, "") - if err != nil { - libOpentelemetry.HandleSpanError(span, "build global settings target", err) - return nil, domain.RevisionZero, fmt.Errorf("build global setting target: %w", err) - } - result, err := builder.store.Get(ctx, target) + effective, revision, err := builder.buildEffectiveValues(ctx, defs, domain.KindSetting, domain.ScopeGlobal, "", "global-override") if err != nil { - libOpentelemetry.HandleSpanError(span, "load global setting overrides", err) - return nil, domain.RevisionZero, fmt.Errorf("get global setting overrides: %w", err) + libOpentelemetry.HandleSpanError(span, "build global settings", err) + return nil, domain.RevisionZero, fmt.Errorf("build global settings: %w", err) } - applyOverrides(effective, result.Entries, "global-override") - setRevision(effective, result.Revision) - - return effective, result.Revision, nil + return effective, revision, nil } // BuildSettings builds effective settings for the requested subject. 
@@ -125,24 +105,14 @@ func (builder *SnapshotBuilder) buildTenantSettings(ctx context.Context, tenantI defs := filterDefsByScope(builder.registry.List(domain.KindSetting), domain.ScopeTenant) effective := initDefaults(defs) - globalTarget, err := domain.NewTarget(domain.KindSetting, domain.ScopeGlobal, "") - if err != nil { - return nil, domain.RevisionZero, fmt.Errorf("build global setting target: %w", err) - } - - globalResult, err := builder.store.Get(ctx, globalTarget) + globalResult, err := builder.readTarget(ctx, domain.KindSetting, domain.ScopeGlobal, "") if err != nil { return nil, domain.RevisionZero, fmt.Errorf("get global setting overrides: %w", err) } applyOverrides(effective, globalResult.Entries, "global-override") - tenantTarget, err := domain.NewTarget(domain.KindSetting, domain.ScopeTenant, tenantID) - if err != nil { - return nil, domain.RevisionZero, fmt.Errorf("build tenant setting target: %w", err) - } - - tenantResult, err := builder.store.Get(ctx, tenantTarget) + tenantResult, err := builder.readTarget(ctx, domain.KindSetting, domain.ScopeTenant, tenantID) if err != nil { return nil, domain.RevisionZero, fmt.Errorf("get tenant setting overrides: %w", err) } @@ -156,6 +126,46 @@ func (builder *SnapshotBuilder) buildTenantSettings(ctx context.Context, tenantI return effective, effectiveRevision, nil } +func (builder *SnapshotBuilder) buildEffectiveValues( + ctx context.Context, + defs []domain.KeyDef, + kind domain.Kind, + scope domain.Scope, + subjectID string, + overrideSource string, +) (map[string]domain.EffectiveValue, domain.Revision, error) { + effective := initDefaults(defs) + + result, err := builder.readTarget(ctx, kind, scope, subjectID) + if err != nil { + return nil, domain.RevisionZero, err + } + + applyOverrides(effective, result.Entries, overrideSource) + setRevision(effective, result.Revision) + + return effective, result.Revision, nil +} + +func (builder *SnapshotBuilder) readTarget( + ctx context.Context, + kind domain.Kind, 
+ scope domain.Scope, + subjectID string, +) (ports.ReadResult, error) { + target, err := domain.NewTarget(kind, scope, subjectID) + if err != nil { + return ports.ReadResult{}, fmt.Errorf("build target: %w", err) + } + + result, err := builder.store.Get(ctx, target) + if err != nil { + return ports.ReadResult{}, fmt.Errorf("get overrides: %w", err) + } + + return result, nil +} + // BuildFull builds a complete snapshot with configs, global settings, and any // requested tenant settings. func (builder *SnapshotBuilder) BuildFull(ctx context.Context, tenantIDs ...string) (domain.Snapshot, error) { @@ -212,7 +222,7 @@ func initDefaults(defs []domain.KeyDef) map[string]domain.EffectiveValue { Override: nil, Source: "default", Revision: domain.RevisionZero, - Redacted: def.RedactPolicy != domain.RedactNone, + Redacted: isRedacted(def), } } diff --git a/commons/systemplane/service/snapshot_builder_test.go b/commons/systemplane/service/snapshot_builder_test.go index 4e28dd18..5c5346e3 100644 --- a/commons/systemplane/service/snapshot_builder_test.go +++ b/commons/systemplane/service/snapshot_builder_test.go @@ -212,3 +212,17 @@ func TestBuildSettings_RetainsRawSecretValuesInSnapshot(t *testing.T) { assert.Equal(t, "changeme", settings["db.password"].Value) assert.True(t, settings["db.password"].Redacted) } + +func TestInitDefaults_ZeroValueRedactPolicyIsNotMarkedRedacted(t *testing.T) { + t.Parallel() + + values := initDefaults([]domain.KeyDef{{ + Key: "app.name", + ValueType: domain.ValueTypeString, + ApplyBehavior: domain.ApplyLiveRead, + }}) + + val, ok := values["app.name"] + require.True(t, ok, "expected key \"app.name\" to exist in defaults") + assert.False(t, val.Redacted) +} diff --git a/commons/systemplane/service/supervisor.go b/commons/systemplane/service/supervisor.go index 3b56cce6..5c07434e 100644 --- a/commons/systemplane/service/supervisor.go +++ b/commons/systemplane/service/supervisor.go @@ -92,13 +92,16 @@ func NewSupervisor(cfg SupervisorConfig) 
(Supervisor, error) { }, nil } -type bundleHolder struct { - bundle domain.RuntimeBundle +// supervisorState holds the immutable (snapshot, bundle) pair that readers +// observe via a single atomic pointer. Replacing two separate atomics with +// one guarantees readers always see a consistent pair. +type supervisorState struct { + snapshot domain.Snapshot + bundle domain.RuntimeBundle } type defaultSupervisor struct { - snapshot atomic.Pointer[domain.Snapshot] - bundle atomic.Pointer[bundleHolder] + state atomic.Pointer[supervisorState] mu sync.Mutex builder *SnapshotBuilder factory ports.BundleFactory @@ -110,22 +113,22 @@ type defaultSupervisor struct { // Current returns the currently active runtime bundle. func (supervisor *defaultSupervisor) Current() domain.RuntimeBundle { - holder := supervisor.bundle.Load() - if holder == nil || isNilRuntimeBundle(holder.bundle) { + st := supervisor.state.Load() + if st == nil || isNilRuntimeBundle(st.bundle) { return nil } - return holder.bundle + return st.bundle } // Snapshot returns the latest published snapshot. func (supervisor *defaultSupervisor) Snapshot() domain.Snapshot { - snap := supervisor.snapshot.Load() - if snap == nil { + st := supervisor.state.Load() + if st == nil { return domain.Snapshot{} } - return *snap + return st.snapshot } // PublishSnapshot publishes a snapshot without rebuilding bundles. 
@@ -144,7 +147,14 @@ func (supervisor *defaultSupervisor) PublishSnapshot(ctx context.Context, snap d supervisor.mu.Lock() defer supervisor.mu.Unlock() - supervisor.snapshot.Store(&snap) + prev := supervisor.state.Load() + + newState := &supervisorState{snapshot: snap} + if prev != nil { + newState.bundle = prev.bundle + } + + supervisor.state.Store(newState) return nil } @@ -164,21 +174,19 @@ func (supervisor *defaultSupervisor) ReconcileCurrent(ctx context.Context, snap supervisor.mu.Lock() defer supervisor.mu.Unlock() - holder := supervisor.bundle.Load() - if holder == nil || isNilRuntimeBundle(holder.bundle) { + prev := supervisor.state.Load() + if prev == nil || isNilRuntimeBundle(prev.bundle) { libOpentelemetry.HandleSpanError(span, "missing current bundle", domain.ErrNoCurrentBundle) return domain.ErrNoCurrentBundle } - previous := supervisor.snapshot.Load() - supervisor.snapshot.Store(&snap) + // Optimistically swap to new snapshot, keeping the same bundle. + supervisor.state.Store(&supervisorState{snapshot: snap, bundle: prev.bundle}) - currentBundle := holder.bundle for _, reconciler := range supervisor.reconcilers { - if err := reconciler.Reconcile(ctx, currentBundle, currentBundle, snap); err != nil { - if previous != nil { - supervisor.snapshot.Store(previous) - } + if err := reconciler.Reconcile(ctx, prev.bundle, prev.bundle, snap); err != nil { + // Rollback: restore previous state atomically. + supervisor.state.Store(prev) libOpentelemetry.HandleSpanError(span, "reconcile current bundle", err) @@ -204,33 +212,20 @@ func (supervisor *defaultSupervisor) Reload(ctx context.Context, reason string, supervisor.mu.Lock() defer supervisor.mu.Unlock() - tenantIDs := mergeUniqueTenantIDs(cachedTenantIDs(supervisor.snapshot.Load()), extraTenantIDs) - - snap, err := supervisor.builder.BuildFull(ctx, tenantIDs...) 
- if err != nil { - libOpentelemetry.HandleSpanError(span, "build full snapshot", err) - return fmt.Errorf("reload: %w: %w", domain.ErrSnapshotBuildFailed, err) + var prevSnap *domain.Snapshot + if st := supervisor.state.Load(); st != nil { + prevSnap = &st.snapshot } - prevSnap := supervisor.snapshot.Load() - prevHolder := supervisor.bundle.Load() + tenantIDs := mergeUniqueTenantIDs(cachedTenantIDs(prevSnap), extraTenantIDs) - var previousBundle domain.RuntimeBundle - if prevHolder != nil { - previousBundle = prevHolder.bundle - } - - // Try incremental build first if the factory supports it and we have a - // previous snapshot+bundle. This reuses unchanged infrastructure components - // (Postgres, Redis, etc.) instead of rebuilding everything. - // Falls back to full build on failure or when incremental is not available. - candidate, strategy, err := supervisor.buildBundle(ctx, snap, previousBundle, prevSnap) + build, err := supervisor.prepareReloadBuild(ctx, tenantIDs) if err != nil { libOpentelemetry.HandleSpanError(span, "build runtime bundle", err) - return fmt.Errorf("reload: %w: %w", domain.ErrBundleBuildFailed, err) + return err } - if isNilRuntimeBundle(candidate) { + if isNilRuntimeBundle(build.candidate) { libOpentelemetry.HandleSpanError(span, "nil runtime bundle", domain.ErrBundleBuildFailed) return fmt.Errorf("reload: %w: nil runtime bundle", domain.ErrBundleBuildFailed) } @@ -240,34 +235,13 @@ func (supervisor *defaultSupervisor) Reload(ctx context.Context, reason string, // state corruption when incremental builds nil-out transferred pointers in // the previous bundle — if we stored the candidate first and a reconciler // failed, the "rollback" would restore a gutted previous bundle. 
- for _, reconciler := range supervisor.reconcilers { - if err := reconciler.Reconcile(ctx, previousBundle, candidate, snap); err != nil { - discardFailedCandidate(ctx, candidate, strategy) - - libOpentelemetry.HandleSpanError(span, "reconcile candidate bundle", err) - - return fmt.Errorf("reload: %s: %w: %w", reconciler.Name(), domain.ErrReconcileFailed, err) - } + if err := supervisor.reconcileCandidateBundle(ctx, build); err != nil { + libOpentelemetry.HandleSpanError(span, "reconcile candidate bundle", err) + return err } // All reconcilers passed — commit atomically. - supervisor.snapshot.Store(&snap) - supervisor.bundle.Store(&bundleHolder{bundle: candidate}) - - if adopter, ok := candidate.(resourceAdopter); ok && !isNilRuntimeBundle(previousBundle) { - adopter.AdoptResourcesFrom(previousBundle) - } - - if supervisor.observer != nil { - supervisor.observer(ReloadEvent{Strategy: strategy, Reason: reason, Snapshot: snap, Bundle: candidate}) - } - - // Close previous AFTER commit so transferred components are not torn down - // while still referenced by the now-active candidate bundle or by external - // runtime delegates that are repointed by the observer. 
- if !isNilRuntimeBundle(previousBundle) { - _ = previousBundle.Close(ctx) - } + supervisor.commitReload(ctx, reason, build) return nil } @@ -284,9 +258,9 @@ func (supervisor *defaultSupervisor) Stop(ctx context.Context) error { supervisor.mu.Lock() defer supervisor.mu.Unlock() - holder := supervisor.bundle.Load() - if holder != nil && !isNilRuntimeBundle(holder.bundle) { - if err := holder.bundle.Close(ctx); err != nil { + st := supervisor.state.Load() + if st != nil && !isNilRuntimeBundle(st.bundle) { + if err := st.bundle.Close(ctx); err != nil { libOpentelemetry.HandleSpanError(span, "close current bundle", err) return fmt.Errorf("stop close current bundle: %w", err) } diff --git a/commons/systemplane/service/supervisor_helpers.go b/commons/systemplane/service/supervisor_helpers.go index 80b46cb5..32df7b5f 100644 --- a/commons/systemplane/service/supervisor_helpers.go +++ b/commons/systemplane/service/supervisor_helpers.go @@ -8,6 +8,7 @@ import ( "slices" "sort" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" libCommons "github.com/LerianStudio/lib-commons/v4/commons" @@ -24,6 +25,25 @@ type rollbackDiscarder interface { Discard(ctx context.Context) error } +type reloadBuild struct { + snapshot domain.Snapshot + previousSnap *domain.Snapshot + previousBundle domain.RuntimeBundle + candidate domain.RuntimeBundle + strategy BuildStrategy +} + +// recordCleanupError records a best-effort cleanup error to the active span. +// Cleanup failures are not actionable by callers but must not go invisible. 
+func recordCleanupError(ctx context.Context, phase string, err error) { + span := trace.SpanFromContext(ctx) + if span.IsRecording() { + span.RecordError(err, trace.WithAttributes( + attribute.String("cleanup.phase", phase), + )) + } +} + func discardFailedCandidate(ctx context.Context, candidate domain.RuntimeBundle, strategy BuildStrategy) { if isNilRuntimeBundle(candidate) { return @@ -38,18 +58,24 @@ func discardFailedCandidate(ctx context.Context, candidate domain.RuntimeBundle, return } - _ = discarder.Discard(ctx) + if err := discarder.Discard(ctx); err != nil { + recordCleanupError(ctx, "discard_incremental_candidate", err) + } return } if discarder, ok := candidate.(rollbackDiscarder); ok { - _ = discarder.Discard(ctx) + if err := discarder.Discard(ctx); err != nil { + recordCleanupError(ctx, "discard_full_candidate", err) + } return } - _ = candidate.Close(ctx) + if err := candidate.Close(ctx); err != nil { + recordCleanupError(ctx, "close_failed_candidate", err) + } } func startSupervisorSpan(ctx context.Context, operation string) (context.Context, trace.Span) { @@ -94,9 +120,10 @@ func (supervisor *defaultSupervisor) buildBundle( } // Discard partially-built candidate to prevent resource leaks. + // Use the incremental discard path so shared/adopted resources are + // released safely rather than double-closed via a plain Close(). if err != nil && !isNilRuntimeBundle(candidate) { - // RuntimeBundle.Close(ctx) is the contract for releasing held resources. - _ = candidate.Close(ctx) + discardFailedCandidate(ctx, candidate, BuildStrategyIncremental) } // Incremental build failed — fall through to full build. } @@ -110,6 +137,83 @@ func (supervisor *defaultSupervisor) buildBundle( return bundle, BuildStrategyFull, nil } +// prepareReloadBuild builds a new snapshot and candidate bundle for a reload. +// +// Concurrency safety: BuildIncremental receives a reference to the live +// previousBundle. 
Its contract is to create a candidate that SHARES resources +// (read-only) from the previous bundle — it must NOT mutate the previous +// bundle's internal pointers. The actual resource transfer (nil-ing +// transferred pointers in previousBundle) happens later in commitReload via +// AdoptResourcesFrom, which runs AFTER the atomic state swap. At that point +// Current() already returns the new candidate, so concurrent readers never +// observe the mutation. +func (supervisor *defaultSupervisor) prepareReloadBuild(ctx context.Context, tenantIDs []string) (reloadBuild, error) { + snap, err := supervisor.builder.BuildFull(ctx, tenantIDs...) + if err != nil { + return reloadBuild{}, fmt.Errorf("reload: %w: %w", domain.ErrSnapshotBuildFailed, err) + } + + var prevSnap *domain.Snapshot + + var previousBundle domain.RuntimeBundle + + if st := supervisor.state.Load(); st != nil { + prevSnap = &st.snapshot + previousBundle = st.bundle + } + + candidate, strategy, err := supervisor.buildBundle(ctx, snap, previousBundle, prevSnap) + if err != nil { + return reloadBuild{}, fmt.Errorf("reload: %w: %w", domain.ErrBundleBuildFailed, err) + } + + return reloadBuild{ + snapshot: snap, + previousSnap: prevSnap, + previousBundle: previousBundle, + candidate: candidate, + strategy: strategy, + }, nil +} + +func (supervisor *defaultSupervisor) reconcileCandidateBundle(ctx context.Context, build reloadBuild) error { + for _, reconciler := range supervisor.reconcilers { + if err := reconciler.Reconcile(ctx, build.previousBundle, build.candidate, build.snapshot); err != nil { + discardFailedCandidate(ctx, build.candidate, build.strategy) + + return fmt.Errorf("reload: %s: %w: %w", reconciler.Name(), domain.ErrReconcileFailed, err) + } + } + + return nil +} + +func (supervisor *defaultSupervisor) commitReload(ctx context.Context, reason string, build reloadBuild) { + supervisor.state.Store(&supervisorState{ + snapshot: build.snapshot, + bundle: build.candidate, + }) + + if adopter, ok := 
build.candidate.(resourceAdopter); ok && !isNilRuntimeBundle(build.previousBundle) { + adopter.AdoptResourcesFrom(build.previousBundle) + } + + if supervisor.observer != nil { + supervisor.observer(ReloadEvent{ + Strategy: build.strategy, + Reason: reason, + Snapshot: build.snapshot, + Bundle: build.candidate, + }) + } + + if !isNilRuntimeBundle(build.previousBundle) { + if err := build.previousBundle.Close(ctx); err != nil { + recordCleanupError(ctx, "close_previous_bundle", err) + } + } +} + // sortReconcilersByPhase returns a copy of the reconciler slice sorted by // phase in ascending order (StateSync → Validation → SideEffect). Reconcilers // within the same phase retain their original relative order (stable sort). @@ -118,7 +222,16 @@ func sortReconcilersByPhase(reconcilers []ports.BundleReconciler) []ports.Bundle copy(sorted, reconcilers) slices.SortStableFunc(sorted, func(a, b ports.BundleReconciler) int { - return int(a.Phase()) - int(b.Phase()) + ap, bp := int(a.Phase()), int(b.Phase()) + if ap < bp { + return -1 + } + + if ap > bp { + return 1 + } + + return 0 }) return sorted @@ -134,10 +247,16 @@ func mergeUniqueTenantIDs(base, extra []string) []string { seen := make(map[string]struct{}, len(base)+len(extra)) + // Build result from a fresh slice to avoid mutating the caller's base. + result := make([]string, 0, len(base)+len(extra)) + for _, id := range base { seen[id] = struct{}{} + result = append(result, id) } + added := false + for _, id := range extra { if id == "" { continue @@ -145,13 +264,20 @@ func mergeUniqueTenantIDs(base, extra []string) []string { if _, exists := seen[id]; !exists { seen[id] = struct{}{} - base = append(base, id) + result = append(result, id) + added = true } } - sort.Strings(base) + // If nothing new was added, return the original base to preserve + // its nil/empty semantics (callers may check for nil). 
+ if !added { + return base + } + + sort.Strings(result) - return base + return result } func cachedTenantIDs(snapshot *domain.Snapshot) []string { diff --git a/commons/systemplane/swagger/spec.json b/commons/systemplane/swagger/spec.json index a699fa39..1f65f181 100644 --- a/commons/systemplane/swagger/spec.json +++ b/commons/systemplane/swagger/spec.json @@ -749,6 +749,10 @@ "type": "string", "description": "Configuration or setting key name" }, + "envVar": { + "type": "string", + "description": "Backing environment variable name when applicable" + }, "kind": { "type": "string", "description": "Key kind (config or setting)" @@ -779,6 +783,10 @@ "type": "boolean", "description": "Whether the value is sensitive and should be redacted in responses" }, + "redactPolicy": { + "type": "string", + "description": "Effective redaction policy exposed by the schema" + }, "description": { "type": "string", "description": "Human-readable description of the key" diff --git a/commons/systemplane/swagger/swagger.go b/commons/systemplane/swagger/swagger.go index 241a8576..0df82ea2 100644 --- a/commons/systemplane/swagger/swagger.go +++ b/commons/systemplane/swagger/swagger.go @@ -30,7 +30,9 @@ func Spec() json.RawMessage { // Merge semantics: // - paths: systemplane paths are added; on key conflict systemplane wins. // - definitions: systemplane definitions are added; on key conflict systemplane wins. -// - tags: systemplane tags are appended to existing tags. +// - tags: systemplane tags are merged with deduplication by tag name. Tags +// from the source with a name already present in the destination are +// skipped. Duplicate names within the source are also collapsed. // // The function does not modify the target slice; it returns a new byte slice. 
func MergeInto(target []byte) ([]byte, error) { @@ -102,9 +104,11 @@ func mergeObjectField(field string, dst, src map[string]json.RawMessage) error { return nil } -// mergeTags appends systemplane tags to the target's existing tags array. -// Duplicate detection is not performed — the consumer is expected to start -// with a spec that does not already contain systemplane tags. +// mergeTags merges systemplane tags into the target's existing tags array with +// deduplication by tag name. Tags from the source whose name already exists in +// the destination are skipped; duplicate names within the source are also +// collapsed. Tag order from the destination is preserved; new source tags are +// appended in their original order. func mergeTags(dst, src map[string]json.RawMessage) error { srcTags, ok := src["tags"] if !ok { @@ -124,7 +128,41 @@ func mergeTags(dst, src map[string]json.RawMessage) error { } } - dstArr = append(dstArr, srcArr...) + // Build a set of existing tag names for deduplication. + existing := make(map[string]bool, len(dstArr)) + for _, raw := range dstArr { + var tag struct { + Name string `json:"name"` + } + if json.Unmarshal(raw, &tag) == nil && tag.Name != "" { + existing[tag.Name] = true + } + } + + // Only append tags whose name is not already present. Also deduplicate + // within srcArr itself by marking each appended name. Malformed tags + // and tags with empty names are rejected to prevent spec corruption. 
+ for _, raw := range srcArr { + var tag struct { + Name string `json:"name"` + } + + if err := json.Unmarshal(raw, &tag); err != nil { + return fmt.Errorf("swagger merge: unmarshal src tag entry: %w", err) + } + + if tag.Name == "" { + continue + } + + if existing[tag.Name] { + continue + } + + existing[tag.Name] = true + + dstArr = append(dstArr, raw) + } merged, err := json.Marshal(dstArr) if err != nil { diff --git a/commons/systemplane/swagger/swagger_test.go b/commons/systemplane/swagger/swagger_test.go index 61d1a037..cb1c5dcb 100644 --- a/commons/systemplane/swagger/swagger_test.go +++ b/commons/systemplane/swagger/swagger_test.go @@ -1,3 +1,5 @@ +//go:build unit + // Copyright 2025 Lerian Studio. package swagger @@ -91,6 +93,23 @@ func TestSpec_ContainsAllDefinitions(t *testing.T) { "spec must contain exactly %d definitions", len(expectedDefinitions)) } +func TestSpec_SchemaEntryDTOContainsMetadataFields(t *testing.T) { + raw := Spec() + + var doc struct { + Definitions map[string]struct { + Properties map[string]json.RawMessage `json:"properties"` + } `json:"definitions"` + } + + require.NoError(t, json.Unmarshal(raw, &doc)) + + schemaEntry, ok := doc.Definitions["systemplane.SchemaEntryDTO"] + require.True(t, ok) + assert.Contains(t, schemaEntry.Properties, "envVar") + assert.Contains(t, schemaEntry.Properties, "redactPolicy") +} + func TestMergeInto_EmptyTarget(t *testing.T) { target := []byte(`{"swagger":"2.0","info":{"title":"Test","version":"1.0"},"paths":{}}`) diff --git a/commons/tenant-manager/client/client.go b/commons/tenant-manager/client/client.go index 9a015248..3d4d84a7 100644 --- a/commons/tenant-manager/client/client.go +++ b/commons/tenant-manager/client/client.go @@ -207,10 +207,10 @@ func NewClient(baseURL string, logger libLog.Logger, opts ...ClientOption) (*Cli } c := &Client{ - baseURL: baseURL, + baseURL: baseURL, httpClient: newDefaultHTTPClient(), - logger: logger, - cacheTTL: defaultCacheTTL, + logger: logger, + cacheTTL: 
defaultCacheTTL, } for _, opt := range opts { diff --git a/commons/tenant-manager/consumer/multi_tenant_optional_rabbitmq_test.go b/commons/tenant-manager/consumer/multi_tenant_optional_rabbitmq_test.go index b959021d..754e6093 100644 --- a/commons/tenant-manager/consumer/multi_tenant_optional_rabbitmq_test.go +++ b/commons/tenant-manager/consumer/multi_tenant_optional_rabbitmq_test.go @@ -10,9 +10,9 @@ import ( "testing" "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil" + amqp "github.com/rabbitmq/amqp091-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - amqp "github.com/rabbitmq/amqp091-go" "go.uber.org/goleak" ) diff --git a/commons/tenant-manager/mongo/manager.go b/commons/tenant-manager/mongo/manager.go index 90ea46cf..07b38536 100644 --- a/commons/tenant-manager/mongo/manager.go +++ b/commons/tenant-manager/mongo/manager.go @@ -220,12 +220,12 @@ func WithIdleTimeout(d time.Duration) Option { // NewManager creates a new MongoDB connection manager. 
func NewManager(c *client.Client, service string, opts ...Option) *Manager { p := &Manager{ - client: c, - service: service, - logger: logcompat.New(nil), - connections: make(map[string]*MongoConnection), - databaseNames: make(map[string]string), - lastAccessed: make(map[string]time.Time), + client: c, + service: service, + logger: logcompat.New(nil), + connections: make(map[string]*MongoConnection), + databaseNames: make(map[string]string), + lastAccessed: make(map[string]time.Time), lastConnectionsCheck: make(map[string]time.Time), connectionsCheckInterval: defaultConnectionsCheckInterval, } diff --git a/commons/tenant-manager/mongo/manager_test.go b/commons/tenant-manager/mongo/manager_test.go index 6931cdff..cf5ed3e6 100644 --- a/commons/tenant-manager/mongo/manager_test.go +++ b/commons/tenant-manager/mongo/manager_test.go @@ -1653,10 +1653,10 @@ func TestManager_RevalidateSettings_RecoverFromPanic(t *testing.T) { // Create a manager with nil client to trigger a panic path manager := &Manager{ - logger: logcompat.New(capLogger), - connections: make(map[string]*MongoConnection), - databaseNames: make(map[string]string), - lastAccessed: make(map[string]time.Time), + logger: logcompat.New(capLogger), + connections: make(map[string]*MongoConnection), + databaseNames: make(map[string]string), + lastAccessed: make(map[string]time.Time), lastConnectionsCheck: make(map[string]time.Time), connectionsCheckInterval: 1 * time.Millisecond, } diff --git a/commons/tenant-manager/postgres/manager.go b/commons/tenant-manager/postgres/manager.go index 39fce8f7..5c60fe48 100644 --- a/commons/tenant-manager/postgres/manager.go +++ b/commons/tenant-manager/postgres/manager.go @@ -107,9 +107,9 @@ type Manager struct { idleTimeout time.Duration // how long before a connection is eligible for eviction lastAccessed map[string]time.Time // LRU tracking per tenant - lastConnectionsCheck map[string]time.Time // tracks per-tenant last settings revalidation time - connectionsCheckInterval 
time.Duration // configurable interval between settings revalidation checks - lastAppliedSettings map[string]appliedSettings // tracks previously applied pool settings per tenant for change detection + lastConnectionsCheck map[string]time.Time // tracks per-tenant last settings revalidation time + connectionsCheckInterval time.Duration // configurable interval between settings revalidation checks + lastAppliedSettings map[string]appliedSettings // tracks previously applied pool settings per tenant for change detection // revalidateWG tracks in-flight revalidatePoolSettings goroutines so Close() // can wait for them to finish before returning. Without this, goroutines @@ -276,18 +276,18 @@ func WithIdleTimeout(d time.Duration) Option { // NewManager creates a new PostgreSQL connection manager. func NewManager(c *client.Client, service string, opts ...Option) *Manager { p := &Manager{ - client: c, - service: service, - logger: logcompat.New(nil), - connections: make(map[string]*PostgresConnection), - lastAccessed: make(map[string]time.Time), + client: c, + service: service, + logger: logcompat.New(nil), + connections: make(map[string]*PostgresConnection), + lastAccessed: make(map[string]time.Time), lastConnectionsCheck: make(map[string]time.Time), - lastAppliedSettings: make(map[string]appliedSettings), + lastAppliedSettings: make(map[string]appliedSettings), connectionsCheckInterval: defaultConnectionsCheckInterval, - maxOpenConns: fallbackMaxOpenConns, - maxIdleConns: fallbackMaxIdleConns, - maxAllowedOpenConns: defaultMaxAllowedOpenConns, - maxAllowedIdleConns: defaultMaxAllowedIdleConns, + maxOpenConns: fallbackMaxOpenConns, + maxIdleConns: fallbackMaxIdleConns, + maxAllowedOpenConns: defaultMaxAllowedOpenConns, + maxAllowedIdleConns: defaultMaxAllowedIdleConns, } for _, opt := range opts { diff --git a/commons/tenant-manager/rabbitmq/manager.go b/commons/tenant-manager/rabbitmq/manager.go index fcd9b9a2..360c4773 100644 --- 
a/commons/tenant-manager/rabbitmq/manager.go +++ b/commons/tenant-manager/rabbitmq/manager.go @@ -132,12 +132,12 @@ func WithTLS() Option { // - opts: Optional configuration options func NewManager(c *client.Client, service string, opts ...Option) *Manager { p := &Manager{ - client: c, - service: service, - logger: logcompat.New(nil), - connections: make(map[string]*amqp.Connection), - cachedURIs: make(map[string]string), - lastAccessed: make(map[string]time.Time), + client: c, + service: service, + logger: logcompat.New(nil), + connections: make(map[string]*amqp.Connection), + cachedURIs: make(map[string]string), + lastAccessed: make(map[string]time.Time), lastConnectionsCheck: make(map[string]time.Time), connectionsCheckInterval: defaultConnectionsCheckInterval, } diff --git a/commons/tenant-manager/rabbitmq/manager_test.go b/commons/tenant-manager/rabbitmq/manager_test.go index 11b01d09..eddd8196 100644 --- a/commons/tenant-manager/rabbitmq/manager_test.go +++ b/commons/tenant-manager/rabbitmq/manager_test.go @@ -716,10 +716,10 @@ func TestManager_RevalidateSettings_RecoverFromPanic(t *testing.T) { // Create a manager with nil client to trigger a panic path manager := &Manager{ - logger: logcompat.New(capLogger), - connections: make(map[string]*amqp.Connection), - cachedURIs: make(map[string]string), - lastAccessed: make(map[string]time.Time), + logger: logcompat.New(capLogger), + connections: make(map[string]*amqp.Connection), + cachedURIs: make(map[string]string), + lastAccessed: make(map[string]time.Time), lastConnectionsCheck: make(map[string]time.Time), connectionsCheckInterval: 1 * time.Millisecond, } diff --git a/commons/webhook/deliverer.go b/commons/webhook/deliverer.go new file mode 100644 index 00000000..bb2374df --- /dev/null +++ b/commons/webhook/deliverer.go @@ -0,0 +1,621 @@ +package webhook + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "crypto/tls" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "net/url" + 
"strconv" + "strings" + "sync" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + + "github.com/LerianStudio/lib-commons/v4/commons/backoff" + "github.com/LerianStudio/lib-commons/v4/commons/log" + "github.com/LerianStudio/lib-commons/v4/commons/runtime" +) + +// Defaults for Deliverer configuration. +const ( + defaultMaxConcurrency = 20 + defaultMaxRetries = 3 + defaultBaseDelay = time.Second + defaultHTTPTimeout = 10 * time.Second + defaultMaxIdleConns = 100 + defaultIdlePerHost = 10 + defaultIdleTimeout = 90 * time.Second +) + +// Deliverer sends webhook events to registered endpoints with SSRF protection, +// HMAC-SHA256 signing, and exponential backoff retries. +// +// Create one with NewDeliverer and reuse it across the service lifetime — +// the internal HTTP client maintains a connection pool. +type Deliverer struct { + lister EndpointLister + logger log.Logger + tracer trace.Tracer + metrics DeliveryMetrics + client *http.Client + decryptor SecretDecryptor + maxConc int + maxRetries int +} + +// Option configures a Deliverer at construction time. +type Option func(*Deliverer) + +// WithLogger attaches a structured logger. Nil values are ignored. +func WithLogger(l log.Logger) Option { + return func(d *Deliverer) { + if l != nil { + d.logger = l + } + } +} + +// WithTracer attaches an OpenTelemetry tracer for span creation. Nil values are ignored. +func WithTracer(t trace.Tracer) Option { + return func(d *Deliverer) { + if t != nil { + d.tracer = t + } + } +} + +// WithMetrics attaches a metrics recorder for delivery outcomes. Nil values are ignored. +func WithMetrics(m DeliveryMetrics) Option { + return func(d *Deliverer) { + if m != nil { + d.metrics = m + } + } +} + +// WithMaxConcurrency sets the maximum number of concurrent endpoint deliveries. +// Values ≤ 0 are ignored and the default (20) is used. 
+func WithMaxConcurrency(n int) Option { + return func(d *Deliverer) { + if n > 0 { + d.maxConc = n + } + } +} + +// WithMaxRetries sets the maximum number of retry attempts per endpoint. +// Values ≤ 0 are ignored and the default (3) is used. +func WithMaxRetries(n int) Option { + return func(d *Deliverer) { + if n > 0 { + d.maxRetries = n + } + } +} + +// WithHTTPClient replaces the default HTTP client. Use this to customize +// timeouts, TLS configuration, or proxy settings. Redirect blocking is +// always enforced regardless of the provided client's CheckRedirect +// setting to preserve SSRF protection. +func WithHTTPClient(c *http.Client) Option { + return func(d *Deliverer) { + if c != nil { + clone := *c + clone.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + } + d.client = &clone + } + } +} + +// WithSecretDecryptor sets a function for decrypting endpoint secrets that +// carry the "enc:" prefix. When nil, encrypted secrets cause delivery to +// be skipped with an error (fail-closed). +func WithSecretDecryptor(fn SecretDecryptor) Option { + return func(d *Deliverer) { + d.decryptor = fn + } +} + +// defaultHTTPClient creates an http.Client optimized for webhook delivery. +// Connection pooling avoids TCP+TLS handshake overhead on repeated deliveries +// to the same endpoint — critical at scale where hundreds of webhooks per +// second would otherwise exhaust ephemeral ports and TLS session caches. +func defaultHTTPClient() *http.Client { + return &http.Client{ + Timeout: defaultHTTPTimeout, + Transport: &http.Transport{ + MaxIdleConns: defaultMaxIdleConns, + MaxIdleConnsPerHost: defaultIdlePerHost, + IdleConnTimeout: defaultIdleTimeout, + }, + // Block all redirects. Webhook endpoints must respond directly — following + // redirects would bypass the SSRF pre-check on the initial URL, allowing + // an attacker to 302 to internal addresses (e.g., cloud metadata services). 
+ CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } +} + +// NewDeliverer creates a webhook deliverer that loads endpoints from lister. +// Functional options configure logging, tracing, concurrency, and retries. +// Returns nil when lister is nil — callers that hold a nil *Deliverer are +// safe because Deliver() and DeliverWithResults() already guard against a +// nil receiver and return ErrNilDeliverer / nil respectively. +func NewDeliverer(lister EndpointLister, opts ...Option) *Deliverer { + if lister == nil { + return nil + } + + d := &Deliverer{ + lister: lister, + logger: log.NewNop(), + client: defaultHTTPClient(), + maxConc: defaultMaxConcurrency, + maxRetries: defaultMaxRetries, + } + + for _, opt := range opts { + if opt != nil { + opt(d) + } + } + + return d +} + +// Deliver sends the event to all active endpoints concurrently. +// It returns an error only for pre-flight failures (nil deliverer, nil event, +// endpoint listing errors). Individual endpoint delivery failures are logged +// and recorded via metrics but do not cause Deliver to return an error. 
+func (d *Deliverer) Deliver(ctx context.Context, event *Event) error { + if d == nil { + return ErrNilDeliverer + } + + if event == nil { + return errors.New("webhook: nil event") + } + + ctx, span := d.startSpan(ctx, "webhook.Deliver", + attribute.String("webhook.event_type", event.Type), + ) + defer span.End() + + endpoints, err := d.lister.ListActiveEndpoints(ctx) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "endpoint listing failed") + + return fmt.Errorf("webhook: list endpoints: %w", err) + } + + active := filterActive(endpoints) + if len(active) == 0 { + d.log(ctx, log.LevelDebug, "no active endpoints for event", + log.String("event_type", event.Type), + ) + + return nil + } + + d.fanOut(ctx, active, event) + + return nil +} + +// DeliverWithResults sends the event to all active endpoints and returns +// per-endpoint delivery results. Useful for callers that need to inspect +// or persist individual outcomes. +func (d *Deliverer) DeliverWithResults(ctx context.Context, event *Event) []DeliveryResult { + if d == nil || event == nil { + return nil + } + + ctx, span := d.startSpan(ctx, "webhook.DeliverWithResults", + attribute.String("webhook.event_type", event.Type), + ) + defer span.End() + + endpoints, err := d.lister.ListActiveEndpoints(ctx) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "endpoint listing failed") + + return []DeliveryResult{{ + Error: fmt.Errorf("webhook: list endpoints: %w", err), + }} + } + + active := filterActive(endpoints) + if len(active) == 0 { + return nil + } + + return d.fanOutWithResults(ctx, active, event) +} + +// fanOut delivers to all endpoints concurrently, capped by the semaphore. +// Individual failures are logged but not collected. The call blocks until +// every goroutine has completed, preventing orphaned goroutines during +// graceful shutdown. 
+func (d *Deliverer) fanOut(ctx context.Context, endpoints []Endpoint, event *Event) { + sem := make(chan struct{}, d.maxConc) + + var wg sync.WaitGroup + + for i := range endpoints { + ep := endpoints[i] + + wg.Add(1) + + sem <- struct{}{} + + dlvCtx := context.WithoutCancel(ctx) + + go func() { + defer wg.Done() + defer func() { <-sem }() + defer runtime.RecoverWithPolicyAndContext( + dlvCtx, d.logger, "webhook", "deliver-to-"+ep.ID, runtime.KeepRunning, + ) + + d.deliverToEndpoint(dlvCtx, ep, event) + }() + } + + wg.Wait() +} + +// fanOutWithResults delivers to all endpoints and collects per-endpoint results. +func (d *Deliverer) fanOutWithResults( + ctx context.Context, + endpoints []Endpoint, + event *Event, +) []DeliveryResult { + sem := make(chan struct{}, d.maxConc) + results := make([]DeliveryResult, len(endpoints)) + + var wg sync.WaitGroup + + for i := range endpoints { + ep := endpoints[i] + idx := i + + wg.Add(1) + + sem <- struct{}{} + + dlvCtx := context.WithoutCancel(ctx) + + // Pre-populate so callers always see which endpoint was attempted, + // even if deliverToEndpoint panics before writing the result. + results[idx] = DeliveryResult{EndpointID: ep.ID} + + go func() { + defer wg.Done() + defer func() { <-sem }() + defer runtime.RecoverWithPolicyAndContext( + dlvCtx, d.logger, "webhook", "deliver-to-"+ep.ID, runtime.KeepRunning, + ) + + results[idx] = d.deliverToEndpoint(dlvCtx, ep, event) + }() + } + + wg.Wait() + + return results +} + +// deliverToEndpoint performs the SSRF check, DNS pinning, and retry loop +// for a single endpoint. Returns the delivery result. 
+func (d *Deliverer) deliverToEndpoint(
+	ctx context.Context,
+	ep Endpoint,
+	event *Event,
+) DeliveryResult {
+	ctx, span := d.startSpan(ctx, "webhook.DeliverToEndpoint",
+		attribute.String("webhook.endpoint_id", ep.ID),
+	)
+	defer span.End()
+
+	result := DeliveryResult{EndpointID: ep.ID}
+
+	// --- SSRF validation + DNS pinning (single lookup, eliminates TOCTOU) ---
+	pinnedURL, originalHost, ssrfErr := resolveAndValidateIP(ctx, ep.URL)
+	if ssrfErr != nil {
+		span.RecordError(ssrfErr)
+		span.SetStatus(codes.Error, "SSRF blocked")
+
+		d.log(ctx, log.LevelError, "webhook delivery blocked by SSRF check",
+			log.String("url", sanitizeURL(ep.URL)),
+			log.Err(ssrfErr),
+		)
+
+		result.Error = fmt.Errorf("%w: %w", ErrSSRFBlocked, ssrfErr)
+
+		return result // NOTE(review): no metrics recorded on SSRF rejection — confirm intended
+	}
+
+	// --- Resolve signing secret once before the retry loop ---
+	secret, secretErr := d.resolveSecret(ep.Secret)
+	if secretErr != nil {
+		span.RecordError(secretErr)
+		span.SetStatus(codes.Error, "secret decryption failed")
+
+		d.log(ctx, log.LevelError, "webhook secret decryption failed, skipping delivery",
+			log.String("endpoint_id", ep.ID),
+			log.Err(secretErr),
+		)
+
+		result.Error = secretErr
+
+		return result // NOTE(review): no metrics recorded on secret failure — confirm intended
+	}
+
+	// --- Retry loop ---
+	for attempt := range d.maxRetries + 1 { // maxRetries retries after the initial attempt
+		result.Attempts = attempt + 1
+
+		if attempt > 0 {
+			delay := backoff.ExponentialWithJitter(defaultBaseDelay, attempt-1)
+			if err := backoff.WaitContext(ctx, delay); err != nil { // NOTE(review): fanOut passes a WithoutCancel ctx, so this branch may be unreachable on that path — confirm
+				result.Error = fmt.Errorf("webhook: context cancelled during backoff: %w", err)
+
+				return result
+			}
+		}
+
+		statusCode, err := d.doHTTP(ctx, pinnedURL, originalHost, event, secret)
+		result.StatusCode = statusCode // transport errors yield 0, overwriting any earlier attempt's status
+
+		if err != nil {
+			d.log(ctx, log.LevelWarn, "webhook delivery failed",
+				log.String("url", sanitizeURL(ep.URL)),
+				log.Int("attempt", attempt+1),
+				log.Err(err),
+			)
+
+			continue
+		}
+
+		if statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices {
+			result.Success = true
+			d.recordMetrics(ctx, ep.ID, true, statusCode, result.Attempts)
+
+			d.log(ctx, log.LevelInfo, "webhook delivered",
+				log.String("url", sanitizeURL(ep.URL)),
+				log.String("event_type", event.Type),
+				log.Int("status", statusCode),
+			)
+
+			return result
+		}
+
+		// Non-retryable client errors — break immediately (except 429 Too Many Requests).
+		if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError && statusCode != http.StatusTooManyRequests {
+			result.Error = fmt.Errorf("webhook: non-retryable status %d", statusCode)
+			d.recordMetrics(ctx, ep.ID, false, statusCode, attempt+1)
+
+			return result
+		}
+
+		d.log(ctx, log.LevelWarn, "webhook non-2xx response",
+			log.String("url", sanitizeURL(ep.URL)),
+			log.Int("status", statusCode),
+			log.Int("attempt", attempt+1),
+		)
+	}
+
+	// Exhausted all retries.
+	result.Error = fmt.Errorf("%w: exhausted %d attempts for %s", ErrDeliveryFailed, d.maxRetries+1, sanitizeURL(ep.URL))
+	d.recordMetrics(ctx, ep.ID, false, result.StatusCode, result.Attempts)
+
+	span.RecordError(result.Error)
+	span.SetStatus(codes.Error, "delivery exhausted retries")
+
+	d.log(ctx, log.LevelError, "webhook delivery exhausted retries",
+		log.String("url", sanitizeURL(ep.URL)),
+		log.String("event_type", event.Type),
+		log.Int("attempts", result.Attempts),
+	)
+
+	return result
+}
+
+// doHTTP builds and executes a single HTTP request to the (possibly pinned) URL.
+// Returns the status code and any transport-level error.
+func (d *Deliverer) doHTTP(
+	ctx context.Context,
+	pinnedURL string,
+	originalHost string,
+	event *Event,
+	secret string,
+) (int, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, pinnedURL, bytes.NewReader(event.Payload))
+	if err != nil {
+		return 0, fmt.Errorf("webhook: build request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("X-Webhook-Event", event.Type)
+	req.Header.Set("X-Webhook-Timestamp", strconv.FormatInt(event.Timestamp, 10))
+
+	// When the URL was rewritten to use the pinned IP, set the Host header
+	// for virtual hosting and use TLSClientConfig.ServerName so TLS SNI and
+	// certificate verification use the original hostname (not the IP).
+	if originalHost != "" {
+		req.Host = originalHost
+	}
+
+	if secret != "" {
+		sig := computeHMAC(event.Payload, secret)
+		req.Header.Set("X-Webhook-Signature", "sha256="+sig)
+	}
+
+	client := d.client
+
+	if originalHost != "" && strings.HasPrefix(pinnedURL, "https://") {
+		client = d.httpsClientForPinnedIP(originalHost) // NOTE(review): builds a fresh transport per call; consider caching if this path becomes hot
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return 0, fmt.Errorf("webhook: http request: %w", err)
+	}
+
+	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10)) // drain up to 64KiB so the keep-alive connection can be reused; body content is ignored
+	_ = resp.Body.Close()
+
+	return resp.StatusCode, nil
+}
+
+// resolveSecret decrypts the endpoint secret if it carries the "enc:" prefix.
+// Plaintext secrets and empty strings pass through unchanged.
+func (d *Deliverer) resolveSecret(raw string) (string, error) {
+	if raw == "" {
+		return "", nil
+	}
+
+	if !strings.HasPrefix(raw, "enc:") {
+		return raw, nil
+	}
+
+	if d.decryptor == nil {
+		return "", errors.New("webhook: encrypted secret but no decryptor configured")
+	}
+
+	plaintext, err := d.decryptor(raw[4:]) // strip "enc:" prefix
+	if err != nil {
+		return "", fmt.Errorf("webhook: decrypt secret: %w", err)
+	}
+
+	return plaintext, nil
+}
+
+// computeHMAC returns the hex-encoded HMAC-SHA256 of payload using the given secret.
+//
+// Design note — timestamp not included in signature (by intent):
+// The signature covers the raw payload only, not the X-Webhook-Timestamp value.
+// Some industry integrations (e.g., Stripe) sign "timestamp.payload" to bind the
+// timestamp to the signature and prevent replay attacks. We intentionally do not
+// do this because: (a) changing the input format would silently break all existing
+// consumers who already verify "sha256=HMAC(payload)" and (b) replay protection
+// at the application layer (e.g., idempotency keys, short timestamp windows) is
+// the responsibility of the receiving service. Receivers who need replay protection
+// should validate that X-Webhook-Timestamp is within an acceptable window (e.g.,
+// ±5 minutes) independently of the signature check.
+func computeHMAC(payload []byte, secret string) string {
+	mac := hmac.New(sha256.New, []byte(secret))
+	mac.Write(payload) // hash.Hash.Write never returns an error
+
+	return hex.EncodeToString(mac.Sum(nil))
+}
+
+// filterActive returns only endpoints where Active is true.
+func filterActive(endpoints []Endpoint) []Endpoint {
+	active := make([]Endpoint, 0, len(endpoints))
+
+	for i := range endpoints {
+		if endpoints[i].Active {
+			active = append(active, endpoints[i])
+		}
+	}
+
+	return active
+}
+
+// startSpan creates an OTel span if a tracer is configured, or returns a
+// no-op span otherwise.
+func (d *Deliverer) startSpan(
+	ctx context.Context,
+	name string,
+	attrs ...attribute.KeyValue,
+) (context.Context, trace.Span) {
+	if d.tracer == nil {
+		return ctx, trace.SpanFromContext(ctx) // parent span if one is in ctx, otherwise a no-op span
+	}
+
+	ctx, span := d.tracer.Start(ctx, name, trace.WithAttributes(attrs...)) //nolint:spancheck // span.End is called by the caller
+
+	return ctx, span //nolint:spancheck // callers defer span.End() immediately after startSpan
+}
+
+// log emits a structured log entry if a logger is configured.
+func (d *Deliverer) log(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+	if d.logger == nil {
+		return
+	}
+
+	d.logger.Log(ctx, level, msg, fields...)
+}
+
+// recordMetrics delegates to the configured DeliveryMetrics, if any.
+func (d *Deliverer) recordMetrics(ctx context.Context, endpointID string, success bool, statusCode, attempts int) {
+	if d.metrics == nil {
+		return
+	}
+
+	d.metrics.RecordDelivery(ctx, endpointID, success, statusCode, attempts)
+}
+
+// httpsClientForPinnedIP returns an HTTP client whose TLS config uses the
+// given hostname for SNI and certificate verification. This is necessary
+// when the request URL has been rewritten to an IP address for DNS pinning
+// (SSRF protection) — without this, Go would try to verify the TLS cert
+// against the IP, which fails for hostname-based certificates.
+func (d *Deliverer) httpsClientForPinnedIP(originalHost string) *http.Client {
+	baseTransport := d.client.Transport
+	if baseTransport == nil {
+		baseTransport = http.DefaultTransport
+	}
+
+	transport, ok := baseTransport.(*http.Transport)
+	if !ok {
+		// Non-standard transport — fall back to the default client and let
+		// the caller's transport handle TLS.
+		return d.client
+	}
+
+	pinned := transport.Clone() // Clone deep-copies TLSClientConfig, so mutations below don't leak into d.client
+	if pinned.TLSClientConfig == nil {
+		pinned.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
+	} else if pinned.TLSClientConfig.MinVersion < tls.VersionTLS12 {
+		pinned.TLSClientConfig.MinVersion = tls.VersionTLS12
+	}
+
+	pinned.TLSClientConfig.ServerName = originalHost
+
+	clone := *d.client // shallow copy keeps Timeout/Jar/CheckRedirect; only Transport is swapped
+	clone.Transport = pinned
+
+	return &clone
+}
+
+// sanitizeURL strips query parameters from a URL before logging to prevent
+// credential leakage. Webhook URLs may carry tokens in query params
+// (e.g., ?token=..., ?api_key=...) that must not appear in log output.
+// On parse failure the raw string is returned unchanged so no log line is lost.
+func sanitizeURL(rawURL string) string {
+	u, err := url.Parse(rawURL)
+	if err != nil {
+		return rawURL
+	}
+
+	u.User, u.RawQuery = nil, "" // also redact userinfo (user:pass@host) — credentials can hide there as well as in the query
+
+	return u.String()
+}
diff --git a/commons/webhook/deliverer_test.go b/commons/webhook/deliverer_test.go
new file mode 100644
index 00000000..ef4fe7dc
--- /dev/null
+++ b/commons/webhook/deliverer_test.go
@@ -0,0 +1,902 @@
+//go:build unit
+
+package webhook
+
+import (
+	"context"
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Test helpers
+// ---------------------------------------------------------------------------
+
+// mockLister implements EndpointLister for testing.
+type mockLister struct {
+	endpoints []Endpoint
+	err       error
+}
+
+func (m *mockLister) ListActiveEndpoints(_ context.Context) ([]Endpoint, error) {
+	return m.endpoints, m.err
+}
+
+// mockMetrics implements DeliveryMetrics for testing.
+type mockMetrics struct {
+	mu    sync.Mutex
+	calls []metricCall
+}
+
+type metricCall struct {
+	EndpointID string
+	Success    bool
+	StatusCode int
+	Attempts   int
+}
+
+func (m *mockMetrics) RecordDelivery(_ context.Context, endpointID string, success bool, statusCode int, attempts int) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	m.calls = append(m.calls, metricCall{endpointID, success, statusCode, attempts})
+}
+
+func (m *mockMetrics) getCalls() []metricCall {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	cp := make([]metricCall, len(m.calls))
+	copy(cp, m.calls)
+
+	return cp
+}
+
+// newTestEvent creates a canonical event for tests.
+func newTestEvent() *Event { + return &Event{ + Type: "order.created", + Payload: []byte(`{"id":"123"}`), + Timestamp: 1700000000, + } +} + +// ssrfBypassClient returns an http.Client whose transport dials directly to +// the listener address, regardless of the URL hostname. This lets us put a +// publicly-routable hostname in the endpoint URL (bypassing SSRF validation) +// while actually connecting to the httptest server on 127.0.0.1. +func ssrfBypassClient(listenAddr string) *http.Client { + return &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: func(_ context.Context, network, _ string) (net.Conn, error) { + return net.Dial(network, listenAddr) + }, + }, + } +} + +// startTestServer returns an httptest.Server with the handler and the +// "fake" public URL that points to 93.184.216.34 (example.com) but on the +// server's actual port. Tests use ssrfBypassClient to connect. +func startTestServer(t *testing.T, handler http.Handler) (*httptest.Server, string) { + t.Helper() + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + + // Extract port from the test server (which listens on 127.0.0.1:PORT). + _, port, err := net.SplitHostPort(srv.Listener.Addr().String()) + require.NoError(t, err) + + // Build a URL with a public hostname so validateResolvedIP doesn't block it. + // The ssrfBypassClient transport dials the real server regardless. + publicURL := "http://example.com:" + port + + return srv, publicURL +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestDefaultHTTPClient_BlocksRedirects(t *testing.T) { + t.Parallel() + + client := defaultHTTPClient() + require.NotNil(t, client.CheckRedirect, "CheckRedirect must be set to block redirects") + + // Simulate a redirect: CheckRedirect should return http.ErrUseLastResponse. 
+ err := client.CheckRedirect(nil, nil) + assert.Equal(t, http.ErrUseLastResponse, err, + "CheckRedirect should return http.ErrUseLastResponse to block redirect-following") +} + +func TestDeliver_RedirectNotFollowed(t *testing.T) { + t.Parallel() + + // Server that always redirects — the deliverer must treat this as a non-2xx + // response and not follow the redirect. + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://169.254.169.254/latest/meta-data/", http.StatusFound) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-redir", URL: pubURL, Secret: "", Active: true}, + }, + } + + // Use ssrfBypassClient but add the CheckRedirect policy from defaultHTTPClient. + client := ssrfBypassClient(srv.Listener.Addr().String()) + client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithHTTPClient(client), + ) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 1) + + // The redirect (302) is a non-2xx status, so the delivery should fail. 
+ assert.False(t, results[0].Success, "redirect should not be followed, resulting in non-2xx") + assert.Equal(t, http.StatusFound, results[0].StatusCode, + "status code should be the redirect status, not the target's status") +} + +func TestNewDeliverer_Defaults(t *testing.T) { + t.Parallel() + + lister := &mockLister{} + d := NewDeliverer(lister) + + require.NotNil(t, d) + assert.Equal(t, 20, d.maxConc, "default maxConcurrency should be 20") + assert.Equal(t, 3, d.maxRetries, "default maxRetries should be 3") + assert.NotNil(t, d.client, "default HTTP client should not be nil") + assert.NotNil(t, d.logger, "logger should default to nop logger") + assert.Nil(t, d.tracer, "tracer should be nil when not set") + assert.Nil(t, d.metrics, "metrics should be nil when not set") + assert.Nil(t, d.decryptor, "decryptor should be nil when not set") +} + +func TestNewDeliverer_NilLister(t *testing.T) { + t.Parallel() + + d := NewDeliverer(nil) + assert.Nil(t, d, "NewDeliverer should return nil when lister is nil") + + // Nil *Deliverer is safe to use — Deliver returns ErrNilDeliverer. + err := d.Deliver(context.Background(), newTestEvent()) + require.ErrorIs(t, err, ErrNilDeliverer) + + // DeliverWithResults returns nil slice. 
+ results := d.DeliverWithResults(context.Background(), newTestEvent()) + assert.Nil(t, results) +} + +func TestDeliver_NilDeliverer(t *testing.T) { + t.Parallel() + + var d *Deliverer + + err := d.Deliver(context.Background(), newTestEvent()) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrNilDeliverer) +} + +func TestDeliver_NilEvent(t *testing.T) { + t.Parallel() + + lister := &mockLister{} + d := NewDeliverer(lister) + + err := d.Deliver(context.Background(), nil) + + require.Error(t, err) + assert.Contains(t, err.Error(), "nil event") +} + +func TestDeliver_NoActiveEndpoints(t *testing.T) { + t.Parallel() + + lister := &mockLister{endpoints: []Endpoint{}} + d := NewDeliverer(lister) + + err := d.Deliver(context.Background(), newTestEvent()) + + assert.NoError(t, err, "empty endpoint list should not be an error") +} + +func TestDeliver_ListerError(t *testing.T) { + t.Parallel() + + listErr := errors.New("database connection refused") + lister := &mockLister{err: listErr} + d := NewDeliverer(lister) + + err := d.Deliver(context.Background(), newTestEvent()) + + require.Error(t, err) + assert.ErrorIs(t, err, listErr, "underlying lister error should be wrapped") + assert.Contains(t, err.Error(), "list endpoints") +} + +func TestDeliver_Success(t *testing.T) { + t.Parallel() + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + metrics := &mockMetrics{} + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-1", URL: pubURL, Secret: "test-secret", Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMetrics(metrics), + WithMaxRetries(1), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + err := d.Deliver(context.Background(), newTestEvent()) + require.NoError(t, err) + + // fanOut is async — wait for the metrics recording. 
+ require.Eventually(t, func() bool { + return len(metrics.getCalls()) > 0 + }, 2*time.Second, 10*time.Millisecond) + + calls := metrics.getCalls() + require.Len(t, calls, 1) + assert.True(t, calls[0].Success) + assert.Equal(t, http.StatusOK, calls[0].StatusCode) + assert.Equal(t, "ep-1", calls[0].EndpointID) +} + +func TestDeliver_Retries(t *testing.T) { + t.Parallel() + + var attempts atomic.Int32 + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := attempts.Add(1) + if n < 3 { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + })) + + metrics := &mockMetrics{} + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-retry", URL: pubURL, Secret: "s", Active: true}, + }, + } + + // maxRetries=3 means 4 total attempts (initial + 3 retries). + // Server succeeds on attempt 3. + d := NewDeliverer(lister, + WithMetrics(metrics), + WithMaxRetries(3), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 1) + assert.True(t, results[0].Success) + assert.Equal(t, 3, results[0].Attempts, "should have taken 3 attempts") +} + +func TestDeliver_ExhaustedRetries(t *testing.T) { + t.Parallel() + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + + metrics := &mockMetrics{} + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-fail", URL: pubURL, Secret: "s", Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMetrics(metrics), + WithMaxRetries(1), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 1) + assert.False(t, results[0].Success) + assert.ErrorIs(t, results[0].Error, ErrDeliveryFailed) + assert.Equal(t, 2, results[0].Attempts, 
"initial + 1 retry = 2") +} + +func TestDeliver_HMACSignature(t *testing.T) { + t.Parallel() + + secret := "my-webhook-secret" + event := newTestEvent() + + var gotSig string + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSig = r.Header.Get("X-Webhook-Signature") + w.WriteHeader(http.StatusOK) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-hmac", URL: pubURL, Secret: secret, Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), event) + require.Len(t, results, 1) + require.True(t, results[0].Success) + + // Compute expected HMAC independently. + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(event.Payload) + expected := "sha256=" + hex.EncodeToString(mac.Sum(nil)) + + assert.Equal(t, expected, gotSig, "HMAC signature should match") +} + +func TestDeliver_Headers(t *testing.T) { + t.Parallel() + + event := newTestEvent() + + var ( + gotContentType string + gotEventType string + gotTimestamp string + ) + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotContentType = r.Header.Get("Content-Type") + gotEventType = r.Header.Get("X-Webhook-Event") + gotTimestamp = r.Header.Get("X-Webhook-Timestamp") + w.WriteHeader(http.StatusOK) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-hdr", URL: pubURL, Secret: "", Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), event) + require.Len(t, results, 1) + require.True(t, results[0].Success) + + assert.Equal(t, "application/json", gotContentType) + assert.Equal(t, event.Type, gotEventType) + assert.Equal(t, strconv.FormatInt(event.Timestamp, 10), gotTimestamp) +} + +func 
TestDeliver_EncryptedSecret_WithDecryptor(t *testing.T) { + t.Parallel() + + plainSecret := "decrypted-secret" + event := newTestEvent() + + var gotSig string + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSig = r.Header.Get("X-Webhook-Signature") + w.WriteHeader(http.StatusOK) + })) + + decryptor := func(ciphertext string) (string, error) { + if ciphertext == "abc123" { + return plainSecret, nil + } + + return "", errors.New("unknown ciphertext") + } + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-enc", URL: pubURL, Secret: "enc:abc123", Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithSecretDecryptor(decryptor), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), event) + require.Len(t, results, 1) + require.True(t, results[0].Success) + + // Verify HMAC was computed with the *decrypted* secret. + mac := hmac.New(sha256.New, []byte(plainSecret)) + mac.Write(event.Payload) + expected := "sha256=" + hex.EncodeToString(mac.Sum(nil)) + + assert.Equal(t, expected, gotSig) +} + +func TestDeliver_EncryptedSecret_NoDecryptor(t *testing.T) { + t.Parallel() + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + _ = srv // keep server alive via t.Cleanup + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-no-dec", URL: pubURL, Secret: "enc:ciphertext", Active: true}, + }, + } + + // No decryptor configured — fail-closed: delivery should abort. 
+ d := NewDeliverer(lister, WithMaxRetries(0)) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 1) + assert.False(t, results[0].Success) + require.Error(t, results[0].Error) + assert.Contains(t, results[0].Error.Error(), "no decryptor configured") +} + +func TestDeliverWithResults_ReturnsPerEndpoint(t *testing.T) { + t.Parallel() + + // Server 1: always succeeds. + srv1, pubURL1 := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Server 2: always fails. + srv2, pubURL2 := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + + // Both servers need to be reachable via the same client. Since each + // ssrfBypassClient pins all dials to a single address, we create a + // mux-style transport that dispatches by port. + addr1 := srv1.Listener.Addr().String() + addr2 := srv2.Listener.Addr().String() + + _, port1, _ := net.SplitHostPort(addr1) + _, port2, _ := net.SplitHostPort(addr2) + + client := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + _, p, _ := net.SplitHostPort(addr) + switch p { + case port1: + return net.Dial(network, addr1) + case port2: + return net.Dial(network, addr2) + default: + return net.Dial(network, addr) + } + }, + }, + } + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "good", URL: pubURL1, Secret: "s1", Active: true}, + {ID: "bad", URL: pubURL2, Secret: "s2", Active: true}, + }, + } + + d := NewDeliverer(lister, WithMaxRetries(0), WithHTTPClient(client)) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 2) + + // Build a map for order-independent assertions. 
+ byID := make(map[string]DeliveryResult, len(results)) + for _, r := range results { + byID[r.EndpointID] = r + } + + assert.True(t, byID["good"].Success) + assert.False(t, byID["bad"].Success) + assert.ErrorIs(t, byID["bad"].Error, ErrDeliveryFailed) +} + +func TestDeliver_InactiveEndpoints_Skipped(t *testing.T) { + t.Parallel() + + var called atomic.Int32 + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called.Add(1) + w.WriteHeader(http.StatusOK) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "active", URL: pubURL, Secret: "", Active: true}, + {ID: "inactive", URL: pubURL, Secret: "", Active: false}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + + // Only the active endpoint should produce a result. + require.Len(t, results, 1) + assert.Equal(t, "active", results[0].EndpointID) + assert.True(t, results[0].Success) +} + +func TestComputeHMAC(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload []byte + secret string + want string + }{ + { + name: "known vector", + payload: []byte(`{"id":"123"}`), + secret: "secret", + want: func() string { + mac := hmac.New(sha256.New, []byte("secret")) + mac.Write([]byte(`{"id":"123"}`)) + return hex.EncodeToString(mac.Sum(nil)) + }(), + }, + { + name: "empty payload", + payload: []byte{}, + secret: "key", + want: func() string { + mac := hmac.New(sha256.New, []byte("key")) + mac.Write([]byte{}) + return hex.EncodeToString(mac.Sum(nil)) + }(), + }, + { + name: "empty secret", + payload: []byte("data"), + secret: "", + want: func() string { + mac := hmac.New(sha256.New, []byte("")) + mac.Write([]byte("data")) + return hex.EncodeToString(mac.Sum(nil)) + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := 
computeHMAC(tt.payload, tt.secret) + assert.Equal(t, tt.want, got) + }) + } +} + +// --------------------------------------------------------------------------- +// Static HMAC test vector — externally verified hex digest +// --------------------------------------------------------------------------- + +// TestComputeHMAC_StaticVector verifies computeHMAC against a known reference +// value produced by: +// +// echo -n 'test-payload' | openssl dgst -sha256 -hmac 'test-secret' | awk '{print $2}' +// → 5b12467d7c448555779e70d76204105c67d27d1c991f3080c19732f9ac1988ef +func TestComputeHMAC_StaticVector(t *testing.T) { + t.Parallel() + + const ( + payload = "test-payload" + secret = "test-secret" + expected = "5b12467d7c448555779e70d76204105c67d27d1c991f3080c19732f9ac1988ef" + ) + + got := computeHMAC([]byte(payload), secret) + assert.Equal(t, expected, got, + "HMAC-SHA256 of %q with key %q must match externally verified reference hex", payload, secret) +} + +// --------------------------------------------------------------------------- +// Non-retryable 4xx responses — break immediately, do not exhaust retries +// --------------------------------------------------------------------------- + +// TestDeliver_NonRetryable4xx verifies that when a server always returns 400 +// (Bad Request), the deliverer does NOT retry — it returns immediately after +// the first attempt, even when maxRetries is set higher. 
+func TestDeliver_NonRetryable4xx(t *testing.T) { + t.Parallel() + + var attempts atomic.Int32 + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + attempts.Add(1) + w.WriteHeader(http.StatusBadRequest) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-400", URL: pubURL, Secret: "", Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(3), // 3 retries configured, but 4xx must short-circuit + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 1) + + assert.False(t, results[0].Success, "400 response must not be considered successful") + assert.Equal(t, http.StatusBadRequest, results[0].StatusCode) + assert.Equal(t, 1, int(attempts.Load()), + "only 1 HTTP attempt must be made for a non-retryable 4xx status") + assert.Equal(t, 1, results[0].Attempts, + "DeliveryResult.Attempts must reflect the single attempt") + require.Error(t, results[0].Error) + assert.Contains(t, results[0].Error.Error(), "non-retryable status 400") +} + +func TestDeliver_BodyIsSent(t *testing.T) { + t.Parallel() + + ev := newTestEvent() + + var gotBody []byte + + srv, pubURL := startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err == nil { + gotBody = body + } + + w.WriteHeader(http.StatusOK) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-body", URL: pubURL, Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), ev) + require.Len(t, results, 1) + require.True(t, results[0].Success) + + assert.Equal(t, ev.Payload, gotBody) +} + +func TestDeliver_NoSignature_WhenSecretEmpty(t *testing.T) { + t.Parallel() + + var gotSig string + + srv, pubURL := 
startTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSig = r.Header.Get("X-Webhook-Signature") + w.WriteHeader(http.StatusOK) + })) + + lister := &mockLister{ + endpoints: []Endpoint{ + {ID: "ep-nosig", URL: pubURL, Secret: "", Active: true}, + }, + } + + d := NewDeliverer(lister, + WithMaxRetries(0), + WithHTTPClient(ssrfBypassClient(srv.Listener.Addr().String())), + ) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 1) + require.True(t, results[0].Success) + + assert.Empty(t, gotSig, "no signature header when secret is empty") +} + +func TestWithOptions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts []Option + checkConc int + checkRetry int + }{ + { + name: "custom concurrency and retries", + opts: []Option{WithMaxConcurrency(5), WithMaxRetries(10)}, + checkConc: 5, + checkRetry: 10, + }, + { + name: "zero concurrency uses default", + opts: []Option{WithMaxConcurrency(0)}, + checkConc: defaultMaxConcurrency, + checkRetry: defaultMaxRetries, + }, + { + name: "negative retries uses default", + opts: []Option{WithMaxRetries(-1)}, + checkConc: defaultMaxConcurrency, + checkRetry: defaultMaxRetries, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + d := NewDeliverer(&mockLister{}, tt.opts...) 
+ assert.Equal(t, tt.checkConc, d.maxConc) + assert.Equal(t, tt.checkRetry, d.maxRetries) + }) + } +} + +func TestWithHTTPClient_Nil_KeepsDefault(t *testing.T) { + t.Parallel() + + d := NewDeliverer(&mockLister{}, WithHTTPClient(nil)) + assert.NotNil(t, d.client, "nil http.Client option should keep default") +} + +func TestDeliverWithResults_NilDeliverer(t *testing.T) { + t.Parallel() + + var d *Deliverer + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + assert.Nil(t, results) +} + +func TestDeliverWithResults_NilEvent(t *testing.T) { + t.Parallel() + + d := NewDeliverer(&mockLister{}) + + results := d.DeliverWithResults(context.Background(), nil) + assert.Nil(t, results) +} + +func TestDeliverWithResults_ListerError(t *testing.T) { + t.Parallel() + + listErr := errors.New("db down") + d := NewDeliverer(&mockLister{err: listErr}) + + results := d.DeliverWithResults(context.Background(), newTestEvent()) + require.Len(t, results, 1) + assert.ErrorIs(t, results[0].Error, listErr) +} + +func TestFilterActive(t *testing.T) { + t.Parallel() + + all := []Endpoint{ + {ID: "a", Active: true}, + {ID: "b", Active: false}, + {ID: "c", Active: true}, + } + + active := filterActive(all) + require.Len(t, active, 2) + assert.Equal(t, "a", active[0].ID) + assert.Equal(t, "c", active[1].ID) +} + +func TestResolveSecret(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + decryptor SecretDecryptor + want string + wantErr string + }{ + { + name: "empty string passthrough", + raw: "", + want: "", + }, + { + name: "plaintext passthrough", + raw: "my-secret", + want: "my-secret", + }, + { + name: "encrypted with decryptor", + raw: "enc:cipher", + decryptor: func(_ string) (string, error) { return "plain", nil }, + want: "plain", + }, + { + name: "encrypted without decryptor", + raw: "enc:cipher", + wantErr: "no decryptor configured", + }, + { + name: "decryptor returns error", + raw: "enc:bad", + decryptor: func(_ string) (string, 
error) { + return "", fmt.Errorf("bad ciphertext") + }, + wantErr: "decrypt secret", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + d := &Deliverer{decryptor: tt.decryptor} + + got, err := d.resolveSecret(tt.raw) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/commons/webhook/doc.go b/commons/webhook/doc.go new file mode 100644 index 00000000..ded26825 --- /dev/null +++ b/commons/webhook/doc.go @@ -0,0 +1,30 @@ +// Package webhook provides a secure webhook delivery engine with SSRF protection, +// HMAC-SHA256 payload signing, and exponential backoff retries. +// +// # Security features +// +// - Two-layer SSRF protection: pre-resolution IP validation + DNS-pinned delivery +// - DNS rebinding prevention via resolved IP pinning +// - Blocks private, loopback, and link-local IP ranges +// - HMAC-SHA256 signature in X-Webhook-Signature header (sha256=HEX) +// - URL query parameters stripped from log output to prevent credential leakage +// +// # Delivery model +// +// - Concurrent delivery with configurable semaphore (default: 20 goroutines) +// - Exponential backoff with jitter (1s, 2s, 4s, ...) +// - Per-endpoint retry with configurable max attempts (default: 3) +// +// # HMAC signature scope +// +// The X-Webhook-Signature header carries HMAC-SHA256 computed over the raw +// payload bytes only. X-Webhook-Timestamp is sent as a separate informational +// header but is NOT covered by the HMAC — an attacker who captures a valid +// payload+signature pair can replay it with a fresh timestamp. +// +// Receivers who need replay protection must implement it independently, for +// example by tracking event IDs or embedding a nonce in the payload itself. +// Timestamp-window checks alone are insufficient because the timestamp is +// unsigned. 
Including the timestamp in the HMAC would be a breaking change +// for existing consumers. +package webhook diff --git a/commons/webhook/errors.go b/commons/webhook/errors.go new file mode 100644 index 00000000..57edcfee --- /dev/null +++ b/commons/webhook/errors.go @@ -0,0 +1,22 @@ +package webhook + +import "errors" + +var ( + // ErrNilDeliverer is returned when a method is called on a nil Deliverer receiver. + ErrNilDeliverer = errors.New("webhook: nil deliverer") + + // ErrSSRFBlocked is returned when an endpoint URL is rejected by SSRF protection. + // This includes private/loopback/link-local/CGNAT/RFC-reserved IP ranges, + // disallowed URL schemes (anything other than http/https), and DNS lookup failures + // (fail-closed). + ErrSSRFBlocked = errors.New("webhook: SSRF blocked") + + // ErrDeliveryFailed is returned when delivery to an endpoint exhausts all retry attempts. + ErrDeliveryFailed = errors.New("webhook: delivery failed") + + // ErrInvalidURL is returned when an endpoint URL fails validation. This covers + // parse errors, empty hostnames, unresolvable DNS (no valid IPs), and other + // structural URL problems. + ErrInvalidURL = errors.New("webhook: invalid URL") +) diff --git a/commons/webhook/ssrf.go b/commons/webhook/ssrf.go new file mode 100644 index 00000000..46637131 --- /dev/null +++ b/commons/webhook/ssrf.go @@ -0,0 +1,152 @@ +package webhook + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" +) + +// cidr4 constructs a static *net.IPNet for an IPv4 CIDR. Using a helper keeps +// each entry in the table below to a single readable line. All four octets are +// accepted so future CIDR additions don't require changing the signature. +// +//nolint:unparam // d is currently always 0 for these network addresses; kept for generality. 
+func cidr4(a, b, c, d byte, ones int) *net.IPNet { + return &net.IPNet{ + IP: net.IPv4(a, b, c, d).To4(), + Mask: net.CIDRMask(ones, 32), + } +} + +// cgnatBlock is the CGNAT (Carrier-Grade NAT) range defined by RFC 6598. +// Cloud providers frequently use this range for internal routing, so it must +// be blocked to prevent SSRF via addresses like 100.64.0.1. +// +//nolint:gochecknoglobals // package-level CIDR block is intentional for SSRF protection +var cgnatBlock = cidr4(100, 64, 0, 0, 10) + +// additionalBlockedRanges holds CIDR blocks that are not covered by the +// standard net.IP predicates (IsPrivate, IsLoopback, etc.) but must be +// blocked to prevent SSRF attacks. +// +// All entries are compile-time-constructed net.IPNet literals — no runtime +// string parsing, no init() required, and typos surface as test failures +// rather than startup panics. +// +//nolint:gochecknoglobals // package-level slice is intentional for SSRF protection +var additionalBlockedRanges = []*net.IPNet{ + cidr4(0, 0, 0, 0, 8), // 0.0.0.0/8 — "this network" (RFC 1122 §3.2.1.3) + cidr4(192, 0, 0, 0, 24), // 192.0.0.0/24 — IETF protocol assignments (RFC 6890) + cidr4(192, 0, 2, 0, 24), // 192.0.2.0/24 — TEST-NET-1 documentation (RFC 5737) + cidr4(198, 18, 0, 0, 15), // 198.18.0.0/15 — benchmarking (RFC 2544) + cidr4(198, 51, 100, 0, 24), // 198.51.100.0/24 — TEST-NET-2 documentation (RFC 5737) + cidr4(203, 0, 113, 0, 24), // 203.0.113.0/24 — TEST-NET-3 documentation (RFC 5737) + cidr4(240, 0, 0, 0, 4), // 240.0.0.0/4 — reserved/future use (RFC 1112) +} + +// resolveAndValidateIP performs a single DNS lookup for the hostname in rawURL, +// validates every resolved IP against the SSRF blocklist, and returns a new URL +// with the hostname replaced by the first resolved IP (DNS pinning). 
+// +// Combining validation and pinning into one lookup eliminates the TOCTOU window +// that exists when validateResolvedIP and pinResolvedIP are called sequentially: +// a DNS rebinding attack could change the record between those two calls, causing +// the pinned IP to differ from the validated one. +// +// On success it returns: +// - pinnedURL — original URL with the hostname replaced by the first resolved IP. +// - originalHost — the original hostname, for use as the HTTP Host header (TLS SNI). +// +// DNS lookup failures are fail-closed: if the hostname cannot be resolved, the +// URL is rejected. When no resolved IP can be parsed from the DNS response the +// URL is considered unresolvable and an error is returned. +func resolveAndValidateIP(ctx context.Context, rawURL string) (pinnedURL string, originalHost string, err error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", "", fmt.Errorf("%w: %w", ErrInvalidURL, err) + } + + scheme := strings.ToLower(u.Scheme) + if scheme != "http" && scheme != "https" { + return "", "", fmt.Errorf("%w: scheme %q not allowed", ErrSSRFBlocked, scheme) + } + + host := u.Hostname() + if host == "" { + return "", "", fmt.Errorf("%w: empty hostname", ErrInvalidURL) + } + + ips, dnsErr := net.DefaultResolver.LookupHost(ctx, host) + if dnsErr != nil { + return "", "", fmt.Errorf("%w: DNS lookup failed for %s: %w", ErrSSRFBlocked, host, dnsErr) + } + + if len(ips) == 0 { + return "", "", fmt.Errorf("%w: DNS returned no addresses for %s", ErrSSRFBlocked, host) + } + + var firstValidIP string + + for _, ipStr := range ips { + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + + if isPrivateIP(ip) { + return "", "", fmt.Errorf("%w: resolved IP %s is private/loopback", ErrSSRFBlocked, ipStr) + } + + if firstValidIP == "" { + firstValidIP = ipStr + } + } + + if firstValidIP == "" { + return "", "", fmt.Errorf("%w: no valid IPs resolved for %s", ErrInvalidURL, host) + } + + // Pin to first valid resolved IP to 
prevent DNS rebinding across retries. + port := u.Port() + + switch { + case port != "": + u.Host = net.JoinHostPort(firstValidIP, port) + case strings.Contains(firstValidIP, ":"): + // Bare IPv6 literal must be bracket-wrapped for url.URL.Host. + u.Host = "[" + firstValidIP + "]" + default: + u.Host = firstValidIP + } + + return u.String(), host, nil +} + +// isPrivateIP reports whether ip is in a private, loopback, link-local, +// unspecified, CGNAT, multicast, or other reserved range that must not be +// contacted by webhook delivery (SSRF protection). +// +// In addition to the ranges covered by the standard net.IP predicates, this +// function checks the additionalBlockedRanges slice which covers RFC-defined +// special-purpose blocks not included in Go's net package. +func isPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || + ip.IsPrivate() || + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || + ip.IsMulticast() || + ip.IsUnspecified() || + cgnatBlock.Contains(ip) { + return true + } + + for _, block := range additionalBlockedRanges { + if block.Contains(ip) { + return true + } + } + + return false +} diff --git a/commons/webhook/ssrf_test.go b/commons/webhook/ssrf_test.go new file mode 100644 index 00000000..456cb2dd --- /dev/null +++ b/commons/webhook/ssrf_test.go @@ -0,0 +1,274 @@ +//go:build unit + +package webhook + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// isPrivateIP — pure function, no DNS, safe to unit-test exhaustively. 
+// --------------------------------------------------------------------------- + +func TestIsPrivateIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + private bool + }{ + // Loopback + {name: "IPv4 loopback", ip: "127.0.0.1", private: true}, + {name: "IPv6 loopback", ip: "::1", private: true}, + + // RFC 1918 private ranges + {name: "10.0.0.0/8", ip: "10.0.0.1", private: true}, + {name: "10.255.255.255", ip: "10.255.255.255", private: true}, + {name: "172.16.0.0/12", ip: "172.16.0.1", private: true}, + {name: "172.31.255.255", ip: "172.31.255.255", private: true}, + {name: "192.168.0.0/16", ip: "192.168.0.1", private: true}, + {name: "192.168.255.255", ip: "192.168.255.255", private: true}, + + // Link-local unicast + {name: "IPv4 link-local", ip: "169.254.1.1", private: true}, + {name: "IPv6 link-local unicast", ip: "fe80::1", private: true}, + + // Link-local multicast + {name: "IPv4 link-local multicast", ip: "224.0.0.1", private: true}, + + // Unspecified addresses (0.0.0.0 / ::) + {name: "IPv4 unspecified", ip: "0.0.0.0", private: true}, + {name: "IPv6 unspecified", ip: "::", private: true}, + + // CGNAT range (RFC 6598): 100.64.0.0/10 + {name: "CGNAT low end", ip: "100.64.0.1", private: true}, + {name: "CGNAT mid range", ip: "100.100.100.100", private: true}, + {name: "CGNAT high end", ip: "100.127.255.254", private: true}, + // Just below CGNAT range — should NOT be private + {name: "100.63.255.255 not CGNAT", ip: "100.63.255.255", private: false}, + // Just above CGNAT range — should NOT be private + {name: "100.128.0.1 not CGNAT", ip: "100.128.0.1", private: false}, + + // Public IPs — should NOT be private + {name: "Google DNS", ip: "8.8.8.8", private: false}, + {name: "Cloudflare", ip: "1.1.1.1", private: false}, + {name: "Public IPv6", ip: "2001:4860:4860::8888", private: false}, + + // Edge: 172.15.x is NOT private (just below 172.16) + {name: "172.15.255.255 not private", ip: "172.15.255.255", private: false}, + // 
Edge: 172.32.x is NOT private (just above 172.31) + {name: "172.32.0.1 not private", ip: "172.32.0.1", private: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ip := net.ParseIP(tt.ip) + assert.NotNil(t, ip, "failed to parse IP: %s", tt.ip) + assert.Equal(t, tt.private, isPrivateIP(ip), + "isPrivateIP(%s) = %v, want %v", tt.ip, isPrivateIP(ip), tt.private) + }) + } +} + +// --------------------------------------------------------------------------- +// resolveAndValidateIP — URL parsing edge cases only. We do NOT hit real DNS. +// --------------------------------------------------------------------------- + +func TestResolveAndValidateIP_InvalidURLMalformed(t *testing.T) { + t.Parallel() + + _, _, err := resolveAndValidateIP(context.Background(), "://missing-scheme") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +func TestResolveAndValidateIP_EmptyHostnameHTTP(t *testing.T) { + t.Parallel() + + _, _, err := resolveAndValidateIP(context.Background(), "http://") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) + assert.Contains(t, err.Error(), "empty hostname") +} + +// --------------------------------------------------------------------------- +// resolveAndValidateIP — scheme validation (blocks non-HTTP/HTTPS schemes). +// --------------------------------------------------------------------------- + +func TestResolveAndValidateIP_UnsupportedSchemes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "gopher scheme", url: "gopher://evil.com"}, + {name: "file scheme", url: "file:///etc/passwd"}, + {name: "ftp scheme", url: "ftp://example.com/file"}, + {name: "javascript scheme", url: "javascript:alert(1)"}, + {name: "data scheme", url: "data:text/html,

hi

"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, _, err := resolveAndValidateIP(context.Background(), tt.url) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrSSRFBlocked) + }) + } +} + +func TestResolveAndValidateIP_AllowedSchemes(t *testing.T) { + t.Parallel() + + // These schemes should pass the scheme check (they may still fail DNS + // resolution, but they should NOT fail with ErrSSRFBlocked for scheme). + tests := []struct { + name string + url string + }{ + {name: "http scheme", url: "http://example.com"}, + {name: "https scheme", url: "https://example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, _, err := resolveAndValidateIP(context.Background(), tt.url) + // The error, if any, should NOT be ErrSSRFBlocked for scheme reasons. + if err != nil { + assert.False(t, errors.Is(err, ErrSSRFBlocked), "http/https should not be SSRF-blocked: %v", err) + } + }) + } +} + +// --------------------------------------------------------------------------- +// resolveAndValidateIP — URL parsing and scheme blocking (no DNS) +// --------------------------------------------------------------------------- + +// TestResolveAndValidateIP_InvalidScheme checks that non-HTTP/HTTPS schemes are +// rejected before DNS lookup. 
+func TestResolveAndValidateIP_InvalidScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "gopher scheme", url: "gopher://example.com"}, + {name: "file scheme", url: "file:///etc/passwd"}, + {name: "ftp scheme", url: "ftp://example.com/file"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, _, err := resolveAndValidateIP(context.Background(), tt.url) + require.Error(t, err) + assert.ErrorIs(t, err, ErrSSRFBlocked, + "non-HTTP/HTTPS scheme must return ErrSSRFBlocked") + }) + } +} + +// TestResolveAndValidateIP_EmptyHostname checks that an empty hostname is rejected. +func TestResolveAndValidateIP_EmptyHostname(t *testing.T) { + t.Parallel() + + _, _, err := resolveAndValidateIP(context.Background(), "http://") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) + assert.Contains(t, err.Error(), "empty hostname") +} + +// TestResolveAndValidateIP_InvalidURL checks that a completely malformed URL +// is rejected with ErrInvalidURL. +func TestResolveAndValidateIP_InvalidURL(t *testing.T) { + t.Parallel() + + _, _, err := resolveAndValidateIP(context.Background(), "://no-scheme") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidURL) +} + +// TestResolveAndValidateIP_PrivateIP confirms that hostnames that resolve to a +// loopback/private address are blocked. "localhost" resolves to 127.0.0.1 +// (loopback), which must trigger ErrSSRFBlocked. +// +// This test does perform a real DNS lookup for "localhost". On all POSIX +// systems "localhost" is defined in /etc/hosts as 127.0.0.1, so the lookup +// is local and requires no network access. Environments that strip /etc/hosts +// may see the DNS fall-back path (original URL returned, no error), in which +// case the test is skipped gracefully. +func TestResolveAndValidateIP_PrivateIP(t *testing.T) { + t.Parallel() + + // Probe the local DNS before asserting — skip if localhost doesn't resolve. 
+ addrs, lookupErr := net.LookupHost("localhost") + if lookupErr != nil || len(addrs) == 0 { + t.Skip("localhost DNS lookup failed or returned no results — skipping SSRF private-IP test") + } + + _, _, err := resolveAndValidateIP(context.Background(), "http://localhost") + require.Error(t, err) + assert.ErrorIs(t, err, ErrSSRFBlocked, + "localhost (127.0.0.1) must be blocked as a private/loopback address") +} + +// --------------------------------------------------------------------------- +// isPrivateIP — additional blocked ranges (RFC-defined special-purpose blocks) +// --------------------------------------------------------------------------- + +// TestResolveAndValidateIP_AllBlockedRanges tests isPrivateIP directly against +// representative IPs from every additional CIDR block listed in ssrf.go: +// +// - 0.0.0.1 → 0.0.0.0/8 "this network" (RFC 1122 §3.2.1.3) +// - 192.0.0.1 → 192.0.0.0/24 IETF protocol assignments (RFC 6890) +// - 192.0.2.1 → 192.0.2.0/24 TEST-NET-1 documentation (RFC 5737) +// - 198.18.0.1 → 198.18.0.0/15 benchmarking (RFC 2544) +// - 198.51.100.1 → 198.51.100.0/24 TEST-NET-2 documentation (RFC 5737) +// - 203.0.113.1 → 203.0.113.0/24 TEST-NET-3 documentation (RFC 5737) +// - 240.0.0.1 → 240.0.0.0/4 reserved / future use (RFC 1112) +// - 239.255.255.255 → multicast (net.IP.IsMulticast) +func TestResolveAndValidateIP_AllBlockedRanges(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip string + }{ + {name: "0/8 this-network", ip: "0.0.0.1"}, + {name: "192.0.0/24 IETF", ip: "192.0.0.1"}, + {name: "192.0.2/24 TEST-NET-1", ip: "192.0.2.1"}, + {name: "198.18/15 benchmarking", ip: "198.18.0.1"}, + {name: "198.51.100/24 TEST-NET-2", ip: "198.51.100.1"}, + {name: "203.0.113/24 TEST-NET-3", ip: "203.0.113.1"}, + {name: "240/4 reserved", ip: "240.0.0.1"}, + {name: "multicast 239.255.255.255", ip: "239.255.255.255"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ip := 
// Endpoint is a webhook receiver: a target URL plus an optional signing secret.
type Endpoint struct {
	// ID uniquely identifies this endpoint (used in logs and metrics).
	ID string

	// URL is the HTTP(S) endpoint that receives webhook events.
	URL string

	// Secret is the HMAC-SHA256 signing secret. It may be stored in plaintext,
	// or encrypted with an "enc:" prefix when a SecretDecryptor is configured.
	Secret string

	// Active indicates whether this endpoint should receive deliveries.
	Active bool
}

// Event is a single webhook event queued for delivery.
type Event struct {
	// Type identifies the event kind (e.g., "order.created", "payment.completed").
	Type string

	// Payload is the JSON-encoded event body.
	Payload []byte

	// Timestamp is the Unix epoch seconds when the event was produced.
	Timestamp int64
}

// EndpointLister retrieves the active webhook endpoints for the current
// context. Implementations typically query a database filtered by a tenant ID
// extracted from the context.
type EndpointLister interface {
	ListActiveEndpoints(ctx context.Context) ([]Endpoint, error)
}

// DeliveryResult records the outcome of delivering one event to one endpoint.
type DeliveryResult struct {
	// EndpointID is the ID of the endpoint that was targeted.
	EndpointID string

	// StatusCode is the HTTP response status, or 0 if the request failed
	// before receiving a response.
	StatusCode int

	// Success is true when the endpoint returned a 2xx status code.
	Success bool

	// Error is non-nil when delivery failed after all retries.
	Error error

	// Attempts is the total number of HTTP requests made (initial + retries).
	Attempts int
}

// DeliveryMetrics lets callers record webhook delivery outcomes for monitoring
// and alerting. Implementations typically emit OpenTelemetry metrics or
// Prometheus counters.
type DeliveryMetrics interface {
	RecordDelivery(ctx context.Context, endpointID string, success bool, statusCode int, attempts int)
}

// SecretDecryptor decrypts an encrypted endpoint secret. The "enc:" prefix is
// stripped before the call — only the ciphertext is passed in. Returning an
// error aborts delivery to that endpoint (fail-closed).
type SecretDecryptor func(encrypted string) (string, error)
+ Success bool + + // Error is non-nil when delivery failed after all retries. + Error error + + // Attempts is the total number of HTTP requests made (initial + retries). + Attempts int +} + +// DeliveryMetrics allows callers to record webhook delivery outcomes for +// monitoring and alerting. Implementations typically emit OpenTelemetry +// metrics or Prometheus counters. +type DeliveryMetrics interface { + RecordDelivery(ctx context.Context, endpointID string, success bool, statusCode int, attempts int) +} + +// SecretDecryptor decrypts an encrypted endpoint secret. The input carries the +// "enc:" prefix stripped — only the ciphertext is passed. Returning an error +// aborts delivery to that endpoint (fail-closed). +type SecretDecryptor func(encrypted string) (string, error) diff --git a/docs/PROJECT_RULES.md b/docs/PROJECT_RULES.md index 9884f81a..ad5e8b18 100644 --- a/docs/PROJECT_RULES.md +++ b/docs/PROJECT_RULES.md @@ -26,34 +26,68 @@ lib-commons/ ├── commons/ # All library packages │ ├── assert/ # Production-safe assertions with telemetry │ ├── backoff/ # Exponential backoff with jitter +│ ├── certificate/ # Thread-safe TLS certificate manager with hot-reload │ ├── circuitbreaker/ # Circuit breaker manager and health checker │ ├── constants/ # Shared constants (headers, errors, pagination) │ ├── cron/ # Cron expression parsing and scheduling │ ├── crypto/ # Hashing and symmetric encryption +│ ├── dlq/ # Redis-backed dead letter queue with consumer and retry │ ├── errgroup/ # Goroutine coordination with panic recovery +│ ├── internal/ # Internal packages (not part of public API) +│ │ └── nilcheck/ # Nil interface detection helpers │ ├── jwt/ # HMAC-based JWT signing and verification │ ├── license/ # License validation and enforcement │ ├── log/ # Logging abstraction (Logger interface) │ ├── mongo/ # MongoDB connector │ ├── net/http/ # Fiber-oriented HTTP helpers and middleware +│ │ ├── idempotency/ # Fiber idempotency middleware (Redis-backed, 
tenant-scoped) │ │ └── ratelimit/ # Redis-backed rate limit storage │ ├── opentelemetry/ # Telemetry bootstrap, propagation, redaction │ │ └── metrics/ # Metric factory and fluent builders +│ ├── outbox/ # Transactional outbox primitives +│ │ └── postgres/ # PostgreSQL outbox adapter with migrations │ ├── pointers/ # Pointer conversion helpers │ ├── postgres/ # PostgreSQL connector with migrations │ ├── rabbitmq/ # RabbitMQ connector │ ├── redis/ # Redis connector (standalone/sentinel/cluster) │ ├── runtime/ # Panic recovery, metrics, safe goroutine wrappers │ ├── safe/ # Panic-free math/regex/slice operations +│ ├── secretsmanager/ # AWS Secrets Manager M2M credential retrieval │ ├── security/ # Sensitive field detection and handling │ ├── server/ # Graceful shutdown and lifecycle (ServerManager) │ ├── shell/ # Makefile includes and shell utilities +│ ├── systemplane/ # Runtime configuration plane (hot-reloadable settings) +│ │ ├── adapters/ # Store (postgres, mongodb) and changefeed adapters +│ │ ├── bootstrap/ # Environment-based config loading +│ │ ├── domain/ # Domain types, entries, revisions, snapshots +│ │ ├── ports/ # Port interfaces (store, changefeed, history, reconciler) +│ │ ├── registry/ # Configuration key registry and validation +│ │ ├── service/ # Service manager, supervisor, escalation +│ │ └── testutil/ # Test fakes for systemplane contracts +│ ├── tenant-manager/ # Multi-tenant database-per-tenant isolation +│ │ ├── cache/ # In-memory tenant cache with LRU eviction +│ │ ├── client/ # HTTP client for tenant-manager API +│ │ ├── consumer/ # Multi-tenant consumer with lazy loading and retry +│ │ ├── core/ # Core types, context, errors, validation +│ │ ├── event/ # Event listener, dispatcher, payloads (Redis Pub/Sub) +│ │ ├── log/ # Tenant-scoped logger +│ │ ├── middleware/ # Fiber middleware (TenantMiddleware with WithPG/WithMB) +│ │ ├── mongo/ # MongoDB tenant manager +│ │ ├── postgres/ # PostgreSQL tenant manager +│ │ ├── rabbitmq/ # RabbitMQ 
tenant manager +│ │ ├── redis/ # Redis tenant client +│ │ ├── s3/ # S3 object storage for tenant provisioning scripts +│ │ ├── tenantcache/ # Tenant cache and loader +│ │ └── valkey/ # Valkey/Redis key patterns │ ├── transaction/ # Typed transaction validation/posting primitives +│ ├── webhook/ # Webhook delivery with HMAC-SHA256 signing and SSRF protection │ ├── zap/ # Zap logging adapter │ ├── app.go # Application bootstrap helpers │ ├── context.go # Context utilities +│ ├── environment.go # Environment detection and security tier mapping │ ├── errors.go # Error definitions -│ ├── os.go # OS utilities +│ ├── os.go # OS utilities and env var helpers +│ ├── security_override.go # ALLOW_* security policy override mechanism │ ├── stringUtils.go # String utilities │ ├── time.go # Time utilities │ └── utils.go # General utility functions @@ -351,15 +385,15 @@ func (c *Client) Connect(ctx context.Context) error { | Category | Allowed Packages | |----------|-----------------| -| Database | `pgx/v5`, `mongo-driver`, `go-redis/v9`, `dbresolver/v2`, `golang-migrate/v4` | +| Database | `pgx/v5`, `mongo-driver`, `mongo-driver/v2`, `go-redis/v9`, `dbresolver/v2`, `golang-migrate/v4` | | Messaging | `amqp091-go` | | HTTP | `gofiber/fiber/v2` | | Logging | `zap`, internal `log` package | -| Testing | `testify`, `go.uber.org/mock`, `miniredis/v2` | -| Observability | `opentelemetry/*`, `otelzap` | -| Utilities | `google/uuid`, `shopspring/decimal`, `go-playground/validator/v10` | +| Testing | `testify`, `go.uber.org/mock`, `miniredis/v2`, `testcontainers-go`, `go-sqlmock`, `goleak` | +| Observability | `opentelemetry/*`, `otelzap`, `grpc`, `protobuf` | +| Utilities | `google/uuid`, `shopspring/decimal`, `go-playground/validator/v10`, `golang.org/x/sync`, `golang.org/x/text` | | Resilience | `sony/gobreaker`, `go-redsync/v4` | -| Security | `golang.org/x/oauth2`, `google.golang.org/api` | +| Security | `golang.org/x/oauth2`, `google.golang.org/api`, `golang-jwt/jwt/v5`, 
`aws-sdk-go-v2` (secretsmanager) | | System | `shirou/gopsutil`, `joho/godotenv` | ### Forbidden Dependencies @@ -411,7 +445,8 @@ safeValue := redactor.Redact(sensitiveField) ### Environment Variables -- Use `SECURE_LOG_FIELDS` for field obfuscation +- Use `LOG_OBFUSCATION_DISABLED` to control HTTP body obfuscation (default: disabled) +- Sensitive field detection uses `commons/security.IsSensitiveField()` with a hardcoded set - Document required environment variables - Provide sensible defaults where safe @@ -428,7 +463,17 @@ safeValue := redactor.Redact(sensitiveField) ### Enabled Linters -`bodyclose`, `depguard`, `dogsled`, `dupword`, `errchkjson`, `gocognit`, `gocyclo`, `loggercheck`, `misspell`, `nakedret`, `nilerr`, `nolintlint`, `prealloc`, `predeclared`, `reassign`, `revive`, `staticcheck`, `thelper`, `tparallel`, `unconvert`, `unparam`, `usestdlibvars`, `wastedassign`, `wsl_v5` +**Existing linters:** +`bodyclose`, `depguard`, `dogsled`, `dupword`, `errchkjson`, `gocognit`, `gocyclo`, `loggercheck`, `misspell`, `nakedret`, `nilerr`, `nolintlint`, `prealloc`, `predeclared`, `reassign`, `revive`, `staticcheck`, `unconvert`, `unparam`, `usestdlibvars`, `wastedassign`, `wsl_v5` + +**Tier 1 — Safety & Correctness:** +`errorlint`, `exhaustive`, `fatcontext`, `forcetypeassert`, `gosec`, `nilnil`, `noctx` + +**Tier 2 — Code Quality & Modernization:** +`goconst`, `gocritic`, `inamedparam`, `intrange`, `mirror`, `modernize`, `perfsprint` + +**Tier 3 — Zero-Issue Guards:** +`asasalint`, `copyloopvar`, `durationcheck`, `exptostd`, `gocheckcompilerdirectives`, `makezero`, `musttag`, `nilnesserr`, `recvcheck`, `rowserrcheck`, `spancheck`, `sqlclosecheck`, `testifylint` ### Formatting @@ -491,7 +536,7 @@ make clean # Clean all build artifacts ## API Invariants -Key v2 API contracts that must be preserved: +Key v4 API contracts that must be preserved: | Package | Invariant | |---------|-----------| @@ -515,6 +560,19 @@ Key v2 API contracts that must be preserved: | 
`transaction` | `BuildIntentPlan()` + `ValidateBalanceEligibility()` + `ApplyPosting()` | | `rabbitmq` | `*Context()` variants for lifecycle; `HealthCheck()` returns `(bool, error)` | | `opentelemetry` | `Redactor` with `RedactionRule`; `NewDefaultRedactor()` / `NewRedactor(rules, mask)` | +| `certificate` | `NewManager(certPath, keyPath)` — empty paths return unconfigured manager (TLS optional). `Rotate(cert, key)` for zero-downtime hot-reload. `GetCertificateFunc()` for `tls.Config.GetCertificate`. All methods nil-safe. | +| `certificate` | Private key parsing order: PKCS#8 → PKCS#1 → EC. Key file must have mode 0600 or stricter (enforced at load time). Public key match validated against cert at load and rotate. | +| `dlq` | `New(conn, keyPrefix, maxRetries, opts...)` returns `*Handler`; `NewConsumer(handler, retryFn, opts...)` returns `(*Consumer, error)`. | +| `dlq` | `Handler.Enqueue(ctx, msg)`, `Dequeue(ctx, source)`, `QueueLength(ctx, source)`, `ScanQueues(ctx, source)`, `PruneExhaustedMessages(ctx, source, limit)`, `ExtractTenantFromKey(key, source)`. | +| `dlq` | `Consumer.Run(ctx)` blocks until ctx cancelled or `Stop()` called. `ProcessOnce(ctx)` exported for testing. Tenant isolation via `tmcore.GetTenantIDContext`. | +| `dlq` | `DLQMetrics` interface: `RecordRetried(ctx, source)`, `RecordExhausted(ctx, source)`. Nil metrics are silently skipped. | +| `net/http/idempotency` | `New(conn, opts...)` returns `*Middleware` (nil when conn is nil). `(*Middleware).Check()` returns a Fiber handler; nil receiver returns pass-through. | +| `net/http/idempotency` | Redis key pattern: `:` with companion response key `…:response`. Default prefix `"idempotency:"`, TTL 7 days, max key length 256. | +| `net/http/idempotency` | Fail-open on Redis errors. GET/OPTIONS/HEAD requests pass through. Handler error deletes key (client may retry). In-flight duplicate returns 409 Conflict. 
| +| `webhook` | `NewDeliverer(lister, opts...)` returns `*Deliverer` (nil when lister is nil). `Deliver(ctx, event)` and `DeliverWithResults(ctx, event)` are the delivery entry points. | +| `webhook` | SSRF protection: `resolveAndValidateIP` performs a single DNS lookup, validates all resolved IPs against private/loopback/CGNAT/RFC-reserved ranges, and pins the URL to the first IP to eliminate TOCTOU. Redirects are blocked at transport layer. | +| `webhook` | HMAC-SHA256 signature sent in `X-Webhook-Signature: sha256=HEX` header over raw payload bytes only (timestamp excluded by design). Encrypted secrets use `enc:` prefix and require `WithSecretDecryptor`. | +| `webhook` | `EndpointLister` interface: `ListActiveEndpoints(ctx) ([]Endpoint, error)`. `DeliveryMetrics` interface: `RecordDelivery(ctx, endpointID, success, statusCode, attempts)`. | ---