diff --git a/internal/observability/request_logger.go b/internal/observability/request_logger.go index 7f6044e..21d87a5 100644 --- a/internal/observability/request_logger.go +++ b/internal/observability/request_logger.go @@ -6,6 +6,7 @@ package observability import ( "log/slog" "net/http" + "net/url" "strings" "time" ) @@ -83,9 +84,8 @@ func (rc *ResponseCapturer) Status() int { // // Parameters: // - logger: structured logger for output -// - vendorHeader: the full header name for vendor ID (e.g., "X-Connect-Vendor-ID") -// constructed by the caller using the configured prefix, so header name -// derivation stays in one place alongside ParseContext (DRY / ADR-005) +// - headerPrefix: the context header prefix (e.g., "X-Connect"). The middleware +// constructs header names internally using the stable suffix constants. // - next: the downstream handler // // Fields emitted: @@ -94,7 +94,10 @@ func (rc *ResponseCapturer) Status() int { // - path: Request path // - status: Response status code // - latency_ms: Time to process request in milliseconds -// - vendor_id: Vendor ID from the vendorHeader request header +// - vendor_id: Vendor ID from the -Vendor-ID request header +// - marketplace_id: Marketplace ID from the -Marketplace-ID request header +// - product_id: Product ID from the -Product-ID request header +// - target_host: Host extracted from the -Target-URL request header // - client_ip: Client IP from proxy headers (X-Forwarded-For > X-Real-IP); // empty when no proxy headers are present (use remote_addr instead) // - remote_addr: Raw TCP peer address (always r.RemoteAddr, useful for @@ -104,7 +107,12 @@ func (rc *ResponseCapturer) Status() int { // already in the request context when this handler receives the request. // Uses defer to ensure logging occurs even if downstream handlers panic // (when used with panic recovery middleware). -func RequestLoggerMiddleware(logger *slog.Logger, vendorHeader string, next http.Handler) http.Handler { +func RequestLoggerMiddleware(logger *slog.Logger, headerPrefix string, next http.Handler) http.Handler { + vendorHdr := headerPrefix + "-Vendor-ID" + marketplaceHdr := headerPrefix + "-Marketplace-ID" + productHdr := headerPrefix + "-Product-ID" + targetURLHdr := headerPrefix + "-Target-URL" + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() ctx := r.Context() @@ -120,7 +128,10 @@ func RequestLoggerMiddleware(logger *slog.Logger, vendorHeader string, next http "path", r.URL.Path, "status", capturer.Status(), "latency_ms", time.Since(start).Milliseconds(), - "vendor_id", r.Header.Get(vendorHeader), + "vendor_id", r.Header.Get(vendorHdr), + "marketplace_id", r.Header.Get(marketplaceHdr), + "product_id", r.Header.Get(productHdr), + "target_host", extractHost(r.Header.Get(targetURLHdr)), "client_ip", ClientIP(r), "remote_addr", r.RemoteAddr, ) @@ -130,6 +141,16 @@ func RequestLoggerMiddleware(logger *slog.Logger, vendorHeader string, next http }) } +// extractHost parses rawURL and returns only the host (with port if present). +// Returns an empty string if the URL is invalid or has no host. +func extractHost(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + return u.Host +} + // ClientIP extracts the client IP from proxy headers only. // Returns the first IP from X-Forwarded-For, or X-Real-IP as fallback. // Returns "" when no proxy headers are present — in that case the diff --git a/internal/observability/request_logger_test.go b/internal/observability/request_logger_test.go index 13f5706..9d99e6a 100644 --- a/internal/observability/request_logger_test.go +++ b/internal/observability/request_logger_test.go @@ -9,22 +9,26 @@ import ( "log/slog" "net/http" "net/http/httptest" + "strings" "testing" ) // logEntry represents a parsed JSON log line for test assertions. type logEntry struct { - Time string `json:"time"` - Level string `json:"level"` - Msg string `json:"msg"` - TraceID string `json:"trace_id"` - Method string `json:"method"` - Path string `json:"path"` - Status int `json:"status"` - LatencyMs int64 `json:"latency_ms"` - VendorID string `json:"vendor_id"` - ClientIP string `json:"client_ip"` - RemoteAddr string `json:"remote_addr"` + Time string `json:"time"` + Level string `json:"level"` + Msg string `json:"msg"` + TraceID string `json:"trace_id"` + Method string `json:"method"` + Path string `json:"path"` + Status int `json:"status"` + LatencyMs int64 `json:"latency_ms"` + VendorID string `json:"vendor_id"` + MarketplaceID string `json:"marketplace_id"` + ProductID string `json:"product_id"` + TargetHost string `json:"target_host"` + ClientIP string `json:"client_ip"` + RemoteAddr string `json:"remote_addr"` } // parseLogEntry parses a single JSON log line from the buffer. @@ -130,10 +134,13 @@ func TestRequestLoggerMiddleware_LogsRequestFields(t *testing.T) { w.WriteHeader(http.StatusOK) }) - handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", inner) r := httptest.NewRequest(http.MethodPost, "/proxy", nil) r = r.WithContext(WithTraceID(r.Context(), "test-trace-123")) r.Header.Set("X-Connect-Vendor-ID", "microsoft") + r.Header.Set("X-Connect-Marketplace-ID", "MP-US") + r.Header.Set("X-Connect-Product-ID", "PRD-001") + r.Header.Set("X-Connect-Target-URL", "https://graph.microsoft.com/v1.0/users") r.RemoteAddr = "10.0.0.1:54321" w := httptest.NewRecorder() @@ -161,6 +168,15 @@ func TestRequestLoggerMiddleware_LogsRequestFields(t *testing.T) { if entry.VendorID != "microsoft" { t.Errorf("vendor_id = %q, want %q", entry.VendorID, "microsoft") } + if entry.MarketplaceID != "MP-US" { + t.Errorf("marketplace_id = %q, want %q", entry.MarketplaceID, "MP-US") + } + if entry.ProductID != "PRD-001" { + t.Errorf("product_id = %q, want %q", entry.ProductID, "PRD-001") + } + if entry.TargetHost != "graph.microsoft.com" { + t.Errorf("target_host = %q, want %q", entry.TargetHost, "graph.microsoft.com") + } if entry.ClientIP != "" { t.Errorf("client_ip = %q, want empty (no proxy headers)", entry.ClientIP) } @@ -177,7 +193,7 @@ func TestRequestLoggerMiddleware_CapturesErrorStatus(t *testing.T) { w.WriteHeader(http.StatusBadGateway) }) - handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", inner) r := httptest.NewRequest(http.MethodGet, "/proxy", nil) r = r.WithContext(WithTraceID(r.Context(), "err-trace")) w := httptest.NewRecorder() @@ -198,7 +214,7 @@ func TestRequestLoggerMiddleware_NoTraceID_LogsEmpty(t *testing.T) { w.WriteHeader(http.StatusOK) }) - handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", inner) r := httptest.NewRequest(http.MethodGet, "/proxy", nil) // Deliberately no trace ID in context w := httptest.NewRecorder() @@ -220,7 +236,7 @@ func TestRequestLoggerMiddleware_LogsOnPanic(t *testing.T) { }) // Wrap with panic recovery INSIDE request logger, so logger still fires - handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", panicRecoveryForTest(panicky)) + handler := RequestLoggerMiddleware(logger, "X-Connect", panicRecoveryForTest(panicky)) r := httptest.NewRequest(http.MethodGet, "/proxy", nil) r = r.WithContext(WithTraceID(r.Context(), "panic-trace")) w := httptest.NewRecorder() @@ -244,7 +260,7 @@ func TestRequestLoggerMiddleware_ClientIPFromXForwardedFor(t *testing.T) { w.WriteHeader(http.StatusOK) }) - handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", inner) r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Forwarded-For", "203.0.113.50") w := httptest.NewRecorder() @@ -269,7 +285,7 @@ func TestRequestLoggerMiddleware_TraceIDFromOuterMiddleware(t *testing.T) { }) // Production order: TraceID (outermost) → Logger → handler - handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", inner) handler = TraceIDMiddleware("X-Trace-ID", handler) r := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -295,7 +311,7 @@ func TestRequestLoggerMiddleware_TraceIDGenerated(t *testing.T) { }) // Production order: TraceID (outermost) → Logger → handler - handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", inner) handler = TraceIDMiddleware("X-Trace-ID", handler) r := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -310,6 +326,86 @@ func TestRequestLoggerMiddleware_TraceIDGenerated(t *testing.T) { } } +func TestRequestLoggerMiddleware_LogsTargetHost_ExtractsHostOnly(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + r := httptest.NewRequest(http.MethodGet, "/proxy", nil) + // URL with query string and path — only the host should appear in the log + r.Header.Set("X-Connect-Target-URL", "https://api.vendor.com/v1/users?api_key=secret&token=abc") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, r) + + entry := parseLogEntry(t, buf.Bytes()) + if entry.TargetHost != "api.vendor.com" { + t.Errorf("target_host = %q, want %q (only host, no path/query)", entry.TargetHost, "api.vendor.com") + } + // Verify query params did not leak into any logged field + logOutput := buf.String() + if strings.Contains(logOutput, "api_key") || strings.Contains(logOutput, "secret") { + t.Errorf("log should not contain query params, got: %s", logOutput) + } +} + +func TestExtractHost(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "full URL with path", + input: "https://api.vendor.com/v1/users", + want: "api.vendor.com", + }, + { + name: "URL with port", + input: "https://api.vendor.com:8443/v1", + want: "api.vendor.com:8443", + }, + { + name: "URL with query string", + input: "https://api.vendor.com/v1?key=secret", + want: "api.vendor.com", + }, + { + name: "URL without path", + input: "https://api.vendor.com", + want: "api.vendor.com", + }, + { + name: "empty URL returns empty", + input: "", + want: "", + }, + { + name: "invalid URL returns empty", + input: "://invalid", + want: "", + }, + { + name: "path-only URL returns empty", + input: "/just/a/path", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractHost(tt.input) + if got != tt.want { + t.Errorf("extractHost(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + // panicRecoveryForTest wraps a handler with basic panic recovery for testing. func panicRecoveryForTest(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/proxy/integration_test.go b/internal/proxy/integration_test.go index 1e6039a..460b371 100644 --- a/internal/proxy/integration_test.go +++ b/internal/proxy/integration_test.go @@ -4,6 +4,7 @@ package proxy_test import ( + "bytes" "context" "errors" "fmt" @@ -427,11 +428,10 @@ func TestIntegration_BackendUnreachable_Returns502(t *testing.T) { } } -func TestIntegration_PluginContextCanceled_NoResponse(t *testing.T) { - // Arrange - plugin that checks for context cancellation +func TestIntegration_PluginContextCanceled_Returns499(t *testing.T) { + // Arrange - plugin that simulates client disconnect plugin := &mockPlugin{ - getCredentialsFn: func(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.Credential, error) { - // Simulate context being cancelled (client disconnected) + getCredentialsFn: func(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { return nil, context.Canceled }, } @@ -449,11 +449,9 @@ func TestIntegration_PluginContextCanceled_NoResponse(t *testing.T) { // Act handler.ServeHTTP(rec, req) - // Assert - when context is canceled, we don't write a response (or write minimal) - // The important thing is we don't crash and don't return 500 - // In practice, client won't see this response since they disconnected - if rec.Code == http.StatusInternalServerError { - t.Error("context.Canceled should not return 500") + // Assert - 499 Client Closed Request (nginx convention for client disconnect) + if rec.Code != proxy.StatusClientClosedRequest { + t.Errorf("status = %d, want %d (StatusClientClosedRequest)", rec.Code, proxy.StatusClientClosedRequest) } } @@ -1909,3 +1907,212 @@ func TestIntegration_NonContextHeaders_PreservedOnForwarding(t *testing.T) { t.Errorf("Accept = %q, want %q", receivedAccept, "application/json") } } + +// ============================================================================= +// Phase 2: DEBUG logging for credential injection +// ============================================================================= + +func TestIntegration_FastPath_LogsCredentialInjection(t *testing.T) { + // Arrange - capture DEBUG log output + var logBuffer bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + originalLogger := slog.Default() + slog.SetDefault(logger) + defer slog.SetDefault(originalLogger) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + plugin := &mockPlugin{ + getCredentialsFn: func(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { + return &sdk.Credential{ + Headers: map[string]string{"Authorization": "Bearer test-token"}, + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + } + + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServerForTarget(t, cfg, backend.URL) + handler := srv.Handler() + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", backend.URL) + req.Header.Set("X-Connect-Vendor-ID", "VA-test") + req.Header.Set("X-Connect-Marketplace-ID", "MP-test") + rec := httptest.NewRecorder() + + // Act + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + // Assert - "credentials injected" DEBUG log with Fast Path fields + logOutput := logBuffer.String() + if !strings.Contains(logOutput, `"msg":"credentials injected"`) { + t.Errorf("expected credentials injected log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"credential_path":"fast"`) { + t.Errorf("expected credential_path=fast in log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"injected_header_count":1`) { + t.Errorf("expected injected_header_count=1 in log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"vendor_id":"VA-test"`) { + t.Errorf("expected vendor_id in log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"plugin_duration_ms"`) { + t.Errorf("expected plugin_duration_ms in log, got: %s", logOutput) + } +} + +func TestIntegration_SlowPath_LogsCredentialInjection(t *testing.T) { + // Arrange - capture DEBUG log output + var logBuffer bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + originalLogger := slog.Default() + slog.SetDefault(logger) + defer slog.SetDefault(originalLogger) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + // Slow path: plugin mutates request directly and returns nil credential + plugin := &mockPlugin{ + getCredentialsFn: func(_ context.Context, _ sdk.TransactionContext, req *http.Request) (*sdk.Credential, error) { + req.Header.Set("Authorization", "Bearer slow-token") + return nil, nil + }, + } + + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServerForTarget(t, cfg, backend.URL) + handler := srv.Handler() + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", backend.URL) + req.Header.Set("X-Connect-Vendor-ID", "VA-slow") + rec := httptest.NewRecorder() + + // Act + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + // Assert - "credentials injected" DEBUG log with Slow Path fields + logOutput := logBuffer.String() + if !strings.Contains(logOutput, `"msg":"credentials injected"`) { + t.Errorf("expected credentials injected log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"credential_path":"slow"`) { + t.Errorf("expected credential_path=slow in log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"vendor_id":"VA-slow"`) { + t.Errorf("expected vendor_id in log, got: %s", logOutput) + } +} + +// ============================================================================= +// Phase 6: DEBUG logpoints for context parsing +// ============================================================================= + +func TestProxy_ContextParsed_DebugLog_LogsHostOnly(t *testing.T) { + // Arrange - capture DEBUG log output + var logBuffer bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + originalLogger := slog.Default() + slog.SetDefault(logger) + defer slog.SetDefault(originalLogger) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + srv := mustNewServerForTarget(t, testConfig(), backend.URL) + handler := srv.Handler() + + // URL with a sensitive path segment and query params — only the host should appear in logs + targetURL := backend.URL + "/v1/users/alice@example.com?api_key=supersecret&token=abc123" + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", targetURL) + req.Header.Set("X-Connect-Vendor-ID", "VA-test") + rec := httptest.NewRecorder() + + // Act + handler.ServeHTTP(rec, req) + + // Assert - only the host appears; path, query, and userinfo must not leak + logOutput := logBuffer.String() + if !strings.Contains(logOutput, `"msg":"transaction context parsed"`) { + t.Errorf("expected 'transaction context parsed' debug log, got: %s", logOutput) + } + if strings.Contains(logOutput, "supersecret") { + t.Errorf("log must not contain sensitive query value 'supersecret', got: %s", logOutput) + } + if strings.Contains(logOutput, "api_key") { + t.Errorf("log must not contain query param name 'api_key', got: %s", logOutput) + } + if strings.Contains(logOutput, "token=") { + t.Errorf("log must not contain 'token=' query param, got: %s", logOutput) + } + if strings.Contains(logOutput, "alice@example.com") { + t.Errorf("log must not contain path segment 'alice@example.com', got: %s", logOutput) + } + if strings.Contains(logOutput, "/v1/users") { + t.Errorf("log must not contain URL path, got: %s", logOutput) + } +} + +// ============================================================================= +// Phase 3: Status 499 on client disconnect +// ============================================================================= + +func TestIntegration_ClientDisconnect_LogsStatus499(t *testing.T) { + // Arrange - capture log output + var logBuffer bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) + originalLogger := slog.Default() + slog.SetDefault(logger) + defer slog.SetDefault(originalLogger) + + plugin := &mockPlugin{ + getCredentialsFn: func(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { + return nil, context.Canceled + }, + } + + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServer(t, cfg) + handler := srv.Handler() + + req := httptest.NewRequest(http.MethodPost, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", "http://example.com") + req.Header.Set("X-Connect-Vendor-ID", "VA-disconnect") + req.Header.Set("Connect-Request-ID", "trace-499-test") + rec := httptest.NewRecorder() + + // Act + handler.ServeHTTP(rec, req) + + // Assert - RequestLoggerMiddleware logs status 499 (not the default 200) + if rec.Code != proxy.StatusClientClosedRequest { + t.Errorf("response status = %d, want %d", rec.Code, proxy.StatusClientClosedRequest) + } + logOutput := logBuffer.String() + if !strings.Contains(logOutput, `"status":499`) { + t.Errorf("log should contain status 499, got: %s", logOutput) + } +} diff --git a/internal/proxy/middleware_bench_test.go b/internal/proxy/middleware_bench_test.go index 79ecd32..4e257da 100644 --- a/internal/proxy/middleware_bench_test.go +++ b/internal/proxy/middleware_bench_test.go @@ -38,7 +38,7 @@ func BenchmarkRequestLoggingMiddleware(b *testing.B) { inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) - handler := observability.RequestLoggerMiddleware(slog.Default(), "X-Connect-Vendor-ID", inner) + handler := observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", inner) req := httptest.NewRequest("GET", "/proxy", nil) b.ReportAllocs() @@ -61,7 +61,7 @@ func BenchmarkMiddlewareStack(b *testing.B) { // Stack middlewares as they would be in production // Order: TraceID (outermost) → Logger → PanicRecovery → handler handler := PanicRecoveryMiddleware(inner) - handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect-Vendor-ID", handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", handler) handler = observability.TraceIDMiddleware("Connect-Request-ID", handler) req := httptest.NewRequest("GET", "/proxy", nil) @@ -82,7 +82,7 @@ func BenchmarkMiddlewareStack_Parallel(b *testing.B) { w.WriteHeader(http.StatusOK) }) handler := PanicRecoveryMiddleware(inner) - handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect-Vendor-ID", handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", handler) handler = observability.TraceIDMiddleware("Connect-Request-ID", handler) b.ReportAllocs() diff --git a/internal/proxy/server.go b/internal/proxy/server.go index ce3105e..660d448 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -18,6 +18,7 @@ import ( "net/http/httputil" "net/url" "os" + "sort" "sync/atomic" "time" @@ -401,6 +402,10 @@ func (s *Server) logStartup() { ) } + if s.config.Plugin == nil { + slog.Info("no plugin configured, requests will be forwarded without credential injection") + } + if s.config.TLS.Enabled { slog.Info("starting proxy server with mTLS (Mode A)", "addr", s.config.Addr, @@ -409,10 +414,13 @@ func (s *Server) logStartup() { "client_auth", "RequireAndVerifyClientCert", ) slog.Info("server configuration", - "read_timeout", s.config.ReadTimeout, - "write_timeout", s.config.WriteTimeout, - "idle_timeout", s.config.IdleTimeout, - "plugin_timeout", s.config.PluginTimeout, + "read_timeout", s.config.ReadTimeout.String(), + "write_timeout", s.config.WriteTimeout.String(), + "idle_timeout", s.config.IdleTimeout.String(), + "plugin_timeout", s.config.PluginTimeout.String(), + "connect_timeout", s.config.ConnectTimeout.String(), + "keepalive_timeout", s.config.KeepAliveTimeout.String(), + "shutdown_timeout", s.config.ShutdownTimeout.String(), "cert_file", s.config.TLS.CertFile, "ca_file", s.config.TLS.CAFile, ) @@ -422,12 +430,33 @@ func (s *Server) logStartup() { "mode", "B (basic)", ) slog.Info("server configuration", - "read_timeout", s.config.ReadTimeout, - "write_timeout", s.config.WriteTimeout, - "idle_timeout", s.config.IdleTimeout, - "plugin_timeout", s.config.PluginTimeout, + "read_timeout", s.config.ReadTimeout.String(), + "write_timeout", s.config.WriteTimeout.String(), + "idle_timeout", s.config.IdleTimeout.String(), + "plugin_timeout", s.config.PluginTimeout.String(), + "connect_timeout", s.config.ConnectTimeout.String(), + "keepalive_timeout", s.config.KeepAliveTimeout.String(), + "shutdown_timeout", s.config.ShutdownTimeout.String(), ) } + + s.logAllowList() +} + +// logAllowList logs the configured allow-list at startup for operational visibility. +func (s *Server) logAllowList() { + hosts := make([]string, 0, len(s.config.AllowList)) + routeCount := 0 + for host, patterns := range s.config.AllowList { + hosts = append(hosts, host) + routeCount += len(patterns) + } + sort.Strings(hosts) // Deterministic output for stable logs and testability + slog.Info("allow list configured", + "hosts", hosts, + "host_count", len(s.config.AllowList), + "route_count", routeCount, + ) } // withMiddleware wraps the handler with the global middleware stack. @@ -449,7 +478,7 @@ func (s *Server) withMiddleware(handler http.Handler) http.Handler { ) } - handler = observability.RequestLoggerMiddleware(slog.Default(), s.config.HeaderPrefix+"-Vendor-ID", handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), s.config.HeaderPrefix, handler) handler = observability.TraceIDMiddleware(s.config.TraceHeader, handler) return handler } @@ -481,6 +510,14 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { return } + slog.Debug("transaction context parsed", + "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", extractTargetHost(txCtx.TargetURL), + ) + targetURL, err := url.Parse(txCtx.TargetURL) if err != nil { s.respondBadRequest(w, traceID, "invalid target URL", err) @@ -507,9 +544,9 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { ) } - r, err = s.injectCredentials(r, txCtx) + r, err = s.injectCredentials(r, txCtx, targetURL.Host) if err != nil { - s.handlePluginError(w, traceID, err) + s.handlePluginError(w, traceID, txCtx, targetURL.Host, err) return } @@ -535,10 +572,13 @@ func (s *Server) respondBadRequest(w http.ResponseWriter, traceID, msg string, e // RedactingHandler can detect and redact them if they leak into log output // (value-based scanning, Layers 3 & 4). // +// targetHost is the already-parsed host from the target URL, passed from handleProxy +// to avoid re-parsing and to keep the field in DEBUG log output. +// // Returns the (possibly updated) request and any error. The caller MUST use // the returned request for all subsequent operations, because the context may // have been enriched with secret values and injected header keys. -func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContext) (*http.Request, error) { +func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContext, targetHost string) (*http.Request, error) { if s.config.Plugin == nil { return r, nil } @@ -571,26 +611,18 @@ func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContex return nil, err } - // Fast Path: plugin returned headers to inject + // Fast Path: plugin returned headers to inject. if cred != nil { - reqCtx := r.Context() - injectedKeys := make([]string, 0, len(cred.Headers)) - for k, v := range cred.Headers { - r.Header.Set(k, v) - injectedKeys = append(injectedKeys, k) - // Store each credential value in context for value-based log redaction. - // The RedactingHandler will scan all slog string attrs and messages - // for these values. Short values (< MinSecretLength) are automatically - // skipped by the handler to avoid false positives. - reqCtx = observability.WithSecretValue(reqCtx, v) - } - // Store injected header keys so the Reflector can strip them from - // responses (prevents credential reflection for non-standard headers). - reqCtx = security.WithInjectedHeaders(reqCtx, injectedKeys) - // Propagate enriched context to the returned request. The reverse proxy - // will Clone this request, so resp.Request.Context() in ModifyResponse - // will carry the secret values and injected header keys. - r = r.WithContext(reqCtx) + r = s.applyFastPathCredentials(r, cred) + slog.Debug("credentials injected", + "trace_id", txCtx.TraceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "target_host", targetHost, + "credential_path", "fast", + "injected_header_count", len(cred.Headers), + "plugin_duration_ms", pluginDuration.Milliseconds(), + ) return r, nil } @@ -599,10 +631,45 @@ func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContex // against our pre-call snapshot. This ensures log redaction and response // stripping work even if the plugin doesn't call WithSecretValue() or // WithInjectedHeaders() itself. - r = s.detectSlowPathInjections(r, headersBefore) + var injectedCount int + r, injectedCount = s.detectSlowPathInjections(r, headersBefore) + slog.Debug("credentials injected", + "trace_id", txCtx.TraceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "target_host", targetHost, + "credential_path", "slow", + "injected_header_count", injectedCount, + "plugin_duration_ms", pluginDuration.Milliseconds(), + ) return r, nil } +// applyFastPathCredentials injects the credential headers into the request and +// enriches the context with secret values and injected header keys. +// All operations must happen together to maintain the context enrichment chain: +// header injection → value-based redaction → response-stripping keys → context propagation. +func (s *Server) applyFastPathCredentials(r *http.Request, cred *sdk.Credential) *http.Request { + reqCtx := r.Context() + injectedKeys := make([]string, 0, len(cred.Headers)) + for k, v := range cred.Headers { + r.Header.Set(k, v) + injectedKeys = append(injectedKeys, k) + // Store each credential value in context for value-based log redaction. + // The RedactingHandler will scan all slog string attrs and messages + // for these values. Short values (< MinSecretLength) are automatically + // skipped by the handler to avoid false positives. + reqCtx = observability.WithSecretValue(reqCtx, v) + } + // Store injected header keys so the Reflector can strip them from + // responses (prevents credential reflection for non-standard headers). + reqCtx = security.WithInjectedHeaders(reqCtx, injectedKeys) + // Propagate enriched context. The reverse proxy will Clone this request, + // so resp.Request.Context() in ModifyResponse carries the secret values + // and injected header keys. + return r.WithContext(reqCtx) +} + // detectSlowPathInjections compares the current request headers against // a pre-plugin snapshot to discover what a Slow Path plugin injected. // Any new or modified headers are treated as injected credentials: @@ -612,7 +679,7 @@ func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContex // This is a safety net — Slow Path plugins MAY still call // observability.WithSecretValue() and security.WithInjectedHeaders() // for finer control, but forgetting to do so is no longer a security gap. -func (s *Server) detectSlowPathInjections(r *http.Request, before http.Header) *http.Request { +func (s *Server) detectSlowPathInjections(r *http.Request, before http.Header) (modified *http.Request, injectedCount int) { var injectedKeys []string reqCtx := r.Context() @@ -631,7 +698,18 @@ func (s *Server) detectSlowPathInjections(r *http.Request, before http.Header) * r = r.WithContext(reqCtx) } - return r + return r, len(injectedKeys) +} + +// extractTargetHost parses rawURL and returns only the host (with port if present). +// Used in log output to avoid leaking sensitive path or query information. +// Returns an empty string if the URL is invalid or has no host. +func extractTargetHost(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + return u.Host } // headerValuesEqual returns true if two header value slices are identical. @@ -666,12 +744,19 @@ func (s *Server) forwardRequest(w http.ResponseWriter, r *http.Request, target * proxy.ServeHTTP(w, r) // #nosec G704 -- target validated against allow-list in handleProxy before reaching here } +// StatusClientClosedRequest is a non-standard status code (nginx convention) +// used when the client disconnects before receiving a response. +const StatusClientClosedRequest = 499 + // handlePluginError handles errors from the plugin. -func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, err error) { - // Check for context errors (timeout/cancellation) +func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, txCtx *sdk.TransactionContext, targetHost string, err error) { if errors.Is(err, context.DeadlineExceeded) { slog.Error("plugin timeout", "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", targetHost, "error", err, ) http.Error(w, "Gateway Timeout", http.StatusGatewayTimeout) @@ -679,16 +764,25 @@ func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, err er } if errors.Is(err, context.Canceled) { - // Client disconnected - don't write response slog.Info("client disconnected", "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", targetHost, ) + // Write 499 so RequestLoggerMiddleware logs the correct status instead of + // the default 200. Do NOT write a body — the client is already gone. + w.WriteHeader(StatusClientClosedRequest) return } - // Generic plugin error slog.Error("plugin error", "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", targetHost, "error", err, ) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -759,6 +853,10 @@ func (s *Server) createReverseProxy(target *url.URL, traceID string, txCtx *sdk. slog.Error("proxy error", "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", target.Host, "error", err, ) http.Error(w, "Bad Gateway", http.StatusBadGateway) @@ -799,7 +897,14 @@ func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionConte var err error action, err = s.config.Plugin.ModifyResponse(ctx, *txCtx, resp) if err != nil { - slog.Warn("plugin ModifyResponse error", "trace_id", traceID, "error", err) + slog.Warn("plugin ModifyResponse error", + "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", resp.Request.URL.Host, + "error", err, + ) // Continue with response processing even if plugin fails } } @@ -814,17 +919,40 @@ func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionConte security.StripInjectedHeaders(resp.Request.Context(), resp.Header) // Step 3: Core error normalization (safety net - unless plugin opted out) - if action == nil || !action.SkipErrorNormalization { - if err := security.NormalizeError(resp, traceID); err != nil { - slog.Error("error normalization failed", - "trace_id", traceID, - "error", err, - ) - // Continue even if normalization fails - response will be sent as-is - } - } + s.applyErrorNormalization(traceID, txCtx, resp, action) - slog.Info("upstream response", "trace_id", traceID, "status", resp.StatusCode, "content_length", resp.ContentLength) + slog.Info("upstream response", + "trace_id", traceID, + "status", resp.StatusCode, + "content_length", resp.ContentLength, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", resp.Request.URL.Host, + ) return nil } } + +// applyErrorNormalization runs Step 3 of the response modification chain. +// If the plugin opted out via ResponseAction.SkipErrorNormalization, it logs +// the opt-out at DEBUG and skips. Otherwise it runs the core error normalization. +func (s *Server) applyErrorNormalization(traceID string, txCtx *sdk.TransactionContext, resp *http.Response, action *sdk.ResponseAction) { + if action != nil && action.SkipErrorNormalization { + slog.Debug("plugin opted out of error normalization", + "trace_id", traceID, + ) + return + } + if err := security.NormalizeError(resp, traceID); err != nil { + slog.Error("error normalization failed", + "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "product_id", txCtx.ProductID, + "target_host", resp.Request.URL.Host, + "error", err, + ) + // Continue even if normalization fails - response will be sent as-is + } +} diff --git a/internal/proxy/server_helpers_test.go b/internal/proxy/server_helpers_test.go new file mode 100644 index 0000000..818aba7 --- /dev/null +++ b/internal/proxy/server_helpers_test.go @@ -0,0 +1,59 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy + +import "testing" + +func TestExtractTargetHost(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "returns host only, strips path", + input: "https://api.vendor.com/v1/users/123", + want: "api.vendor.com", + }, + { + name: "returns host only, strips query string", + input: "https://api.vendor.com/v1?api_key=secret", + want: "api.vendor.com", + }, + { + name: "returns host only, strips userinfo", + input: "https://user:pass@api.vendor.com/v1", + want: "api.vendor.com", + }, + { + name: "preserves port", + input: "https://api.vendor.com:8443/v1", + want: "api.vendor.com:8443", + }, + { + name: "strips all sensitive parts together", + input: "https://user:pass@api.vendor.com/v1/users/alice@example.com?token=abc#frag", + want: "api.vendor.com", + }, + { + name: "invalid URL returns empty", + input: "://invalid", + want: "", + }, + { + name: "empty URL returns empty", + input: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractTargetHost(tt.input) + if got != tt.want { + t.Errorf("extractTargetHost(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index 256ce99..db34f1c 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -375,7 +375,7 @@ func TestMiddlewareStack_PanicLogsCorrectStatus(t *testing.T) { // Apply middleware in production order: TraceID → Logger → PanicRecovery → handler handler := proxy.PanicRecoveryMiddleware(panicHandler) - handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect-Vendor-ID", handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", handler) handler = observability.TraceIDMiddleware("Connect-Request-ID", handler) req := httptest.NewRequest(http.MethodPost, "/test/panic", nil) @@ -420,7 +420,7 @@ func TestMiddlewareStack_NormalRequestLogsCorrectStatus(t *testing.T) { // Apply middleware in production order: TraceID → Logger → PanicRecovery → handler wrapped := proxy.PanicRecoveryMiddleware(handler) - wrapped = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect-Vendor-ID", wrapped) + wrapped = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", wrapped) wrapped = observability.TraceIDMiddleware("Connect-Request-ID", wrapped) req := httptest.NewRequest(http.MethodPost, "/resource", nil) diff --git a/internal/router/middleware.go b/internal/router/middleware.go index 7e26913..c07ad2c 100644 --- a/internal/router/middleware.go +++ b/internal/router/middleware.go @@ -8,6 +8,9 @@ import ( "errors" "log/slog" "net/http" + "net/url" + + "github.com/cloudblue/chaperone/internal/observability" ) // AllowListMiddleware validates incoming requests against the allow list @@ -42,6 +45,7 @@ func (m *AllowListMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) // Missing target URL is a client error if targetURL == "" { slog.Warn("missing target URL header", + "trace_id", observability.TraceIDFromContext(r.Context()), "header", m.headerPrefix+"-Target-URL", "remote_addr", r.RemoteAddr, ) @@ -51,9 +55,9 @@ func (m *AllowListMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) // Validate target URL against allow list if err := m.validator.Validate(targetURL); err != nil { - // Log the validation failure with context - // Note: We log the host but not the full URL to avoid leaking query params slog.Warn("allow list validation failed", + "trace_id", observability.TraceIDFromContext(r.Context()), + "target_host", extractHostFromURL(targetURL), "error", err.Error(), "remote_addr", r.RemoteAddr, ) @@ -73,10 +77,25 @@ func (m *AllowListMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) return } + slog.Debug("allow list validation passed", + "trace_id", observability.TraceIDFromContext(r.Context()), + "target_host", extractHostFromURL(targetURL), + ) + // Validation passed, continue to next handler m.next.ServeHTTP(w, r) } +// extractHostFromURL parses a URL string and returns only the host portion. +// Returns an empty string if the URL is invalid or has no host. +func extractHostFromURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + return u.Host +} + // errorResponse is the JSON structure for error responses. type errorResponse struct { Error string `json:"error"` diff --git a/internal/router/middleware_test.go b/internal/router/middleware_test.go index 225b7df..9250db1 100644 --- a/internal/router/middleware_test.go +++ b/internal/router/middleware_test.go @@ -4,11 +4,15 @@ package router import ( + "bytes" "io" + "log/slog" "net/http" "net/http/httptest" "strings" "testing" + + "github.com/cloudblue/chaperone/internal/observability" ) func TestAllowListMiddleware_ValidRequest(t *testing.T) { @@ -312,6 +316,75 @@ func TestAllowListMiddleware_AllMethods(t *testing.T) { } } +func TestAllowListMiddleware_ValidationPassed_DebugLog(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + prevLogger := slog.Default() + slog.SetDefault(logger) + defer slog.SetDefault(prevLogger) + + allowList := map[string][]string{"api.example.com": {"/**"}} + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req = req.WithContext(observability.WithTraceID(req.Context(), "trace-debug-123")) + req.Header.Set("X-Connect-Target-URL", "https://api.example.com/v1/users") + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rr.Code, http.StatusOK) + } + + logOutput := buf.String() + if !strings.Contains(logOutput, `"msg":"allow list validation passed"`) { + t.Errorf("expected allow list validation passed debug log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"trace_id":"trace-debug-123"`) { + t.Errorf("expected trace_id in log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"target_host":"api.example.com"`) { + t.Errorf("expected target_host in log, got: %s", logOutput) + } +} + +func TestAllowListMiddleware_ValidationFailed_HasTraceID(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + prevLogger := slog.Default() + slog.SetDefault(logger) + defer slog.SetDefault(prevLogger) + + allowList := map[string][]string{"api.example.com": {"/**"}} + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("next handler should not be called") + }) + middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req = req.WithContext(observability.WithTraceID(req.Context(), "trace-fail-456")) + req.Header.Set("X-Connect-Target-URL", "https://evil.com/data") + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rr.Code, http.StatusForbidden) + } + + logOutput := buf.String() + if !strings.Contains(logOutput, `"trace_id":"trace-fail-456"`) { + t.Errorf("expected trace_id in failure log, got: %s", logOutput) + } + if !strings.Contains(logOutput, `"target_host":"evil.com"`) { + t.Errorf("expected target_host in failure log, got: %s", logOutput) + } +} + func TestAllowListMiddleware_EmptyAllowListDeniesAll(t *testing.T) { // Both nil and empty map should deny all testCases := []struct {