Skip to content
Draft
33 changes: 27 additions & 6 deletions internal/observability/request_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package observability
import (
"log/slog"
"net/http"
"net/url"
"strings"
"time"
)
Expand Down Expand Up @@ -83,9 +84,8 @@ func (rc *ResponseCapturer) Status() int {
//
// Parameters:
// - logger: structured logger for output
// - vendorHeader: the full header name for vendor ID (e.g., "X-Connect-Vendor-ID")
// constructed by the caller using the configured prefix, so header name
// derivation stays in one place alongside ParseContext (DRY / ADR-005)
// - headerPrefix: the context header prefix (e.g., "X-Connect"). The middleware
// constructs header names internally using the stable suffix constants.
// - next: the downstream handler
//
// Fields emitted:
Expand All @@ -94,7 +94,10 @@ func (rc *ResponseCapturer) Status() int {
// - path: Request path
// - status: Response status code
// - latency_ms: Time to process request in milliseconds
// - vendor_id: Vendor ID from the vendorHeader request header
// - vendor_id: Vendor ID from the <prefix>-Vendor-ID request header
// - marketplace_id: Marketplace ID from the <prefix>-Marketplace-ID request header
// - product_id: Product ID from the <prefix>-Product-ID request header
// - target_host: Host extracted from the <prefix>-Target-URL request header
// - client_ip: Client IP from proxy headers (X-Forwarded-For > X-Real-IP);
// empty when no proxy headers are present (use remote_addr instead)
// - remote_addr: Raw TCP peer address (always r.RemoteAddr, useful for
Expand All @@ -104,7 +107,12 @@ func (rc *ResponseCapturer) Status() int {
// already in the request context when this handler receives the request.
// Uses defer to ensure logging occurs even if downstream handlers panic
// (when used with panic recovery middleware).
func RequestLoggerMiddleware(logger *slog.Logger, vendorHeader string, next http.Handler) http.Handler {
func RequestLoggerMiddleware(logger *slog.Logger, headerPrefix string, next http.Handler) http.Handler {
vendorHdr := headerPrefix + "-Vendor-ID"
marketplaceHdr := headerPrefix + "-Marketplace-ID"
productHdr := headerPrefix + "-Product-ID"
targetURLHdr := headerPrefix + "-Target-URL"

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
ctx := r.Context()
Expand All @@ -120,7 +128,10 @@ func RequestLoggerMiddleware(logger *slog.Logger, vendorHeader string, next http
"path", r.URL.Path,
"status", capturer.Status(),
"latency_ms", time.Since(start).Milliseconds(),
"vendor_id", r.Header.Get(vendorHeader),
"vendor_id", r.Header.Get(vendorHdr),
"marketplace_id", r.Header.Get(marketplaceHdr),
"product_id", r.Header.Get(productHdr),
"target_host", extractHost(r.Header.Get(targetURLHdr)),
"client_ip", ClientIP(r),
"remote_addr", r.RemoteAddr,
)
Expand All @@ -130,6 +141,16 @@ func RequestLoggerMiddleware(logger *slog.Logger, vendorHeader string, next http
})
}

// extractHost parses rawURL and returns only the host (with port if present).
// Returns an empty string if the URL is invalid or has no host.
func extractHost(rawURL string) string {
u, err := url.Parse(rawURL)
if err != nil || u.Host == "" {
return ""
}
return u.Host
}

// ClientIP extracts the client IP from proxy headers only.
// Returns the first IP from X-Forwarded-For, or X-Real-IP as fallback.
// Returns "" when no proxy headers are present — in that case the
Expand Down
132 changes: 114 additions & 18 deletions internal/observability/request_logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,26 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

// logEntry represents a parsed JSON log line for test assertions.
type logEntry struct {
Time string `json:"time"`
Level string `json:"level"`
Msg string `json:"msg"`
TraceID string `json:"trace_id"`
Method string `json:"method"`
Path string `json:"path"`
Status int `json:"status"`
LatencyMs int64 `json:"latency_ms"`
VendorID string `json:"vendor_id"`
ClientIP string `json:"client_ip"`
RemoteAddr string `json:"remote_addr"`
Time string `json:"time"`
Level string `json:"level"`
Msg string `json:"msg"`
TraceID string `json:"trace_id"`
Method string `json:"method"`
Path string `json:"path"`
Status int `json:"status"`
LatencyMs int64 `json:"latency_ms"`
VendorID string `json:"vendor_id"`
MarketplaceID string `json:"marketplace_id"`
ProductID string `json:"product_id"`
TargetHost string `json:"target_host"`
ClientIP string `json:"client_ip"`
RemoteAddr string `json:"remote_addr"`
}

// parseLogEntry parses a single JSON log line from the buffer.
Expand Down Expand Up @@ -130,10 +134,13 @@ func TestRequestLoggerMiddleware_LogsRequestFields(t *testing.T) {
w.WriteHeader(http.StatusOK)
})

handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner)
handler := RequestLoggerMiddleware(logger, "X-Connect", inner)
r := httptest.NewRequest(http.MethodPost, "/proxy", nil)
r = r.WithContext(WithTraceID(r.Context(), "test-trace-123"))
r.Header.Set("X-Connect-Vendor-ID", "microsoft")
r.Header.Set("X-Connect-Marketplace-ID", "MP-US")
r.Header.Set("X-Connect-Product-ID", "PRD-001")
r.Header.Set("X-Connect-Target-URL", "https://graph.microsoft.com/v1.0/users")
r.RemoteAddr = "10.0.0.1:54321"
w := httptest.NewRecorder()

Expand Down Expand Up @@ -161,6 +168,15 @@ func TestRequestLoggerMiddleware_LogsRequestFields(t *testing.T) {
if entry.VendorID != "microsoft" {
t.Errorf("vendor_id = %q, want %q", entry.VendorID, "microsoft")
}
if entry.MarketplaceID != "MP-US" {
t.Errorf("marketplace_id = %q, want %q", entry.MarketplaceID, "MP-US")
}
if entry.ProductID != "PRD-001" {
t.Errorf("product_id = %q, want %q", entry.ProductID, "PRD-001")
}
if entry.TargetHost != "graph.microsoft.com" {
t.Errorf("target_host = %q, want %q", entry.TargetHost, "graph.microsoft.com")
}
if entry.ClientIP != "" {
t.Errorf("client_ip = %q, want empty (no proxy headers)", entry.ClientIP)
}
Expand All @@ -177,7 +193,7 @@ func TestRequestLoggerMiddleware_CapturesErrorStatus(t *testing.T) {
w.WriteHeader(http.StatusBadGateway)
})

handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner)
handler := RequestLoggerMiddleware(logger, "X-Connect", inner)
r := httptest.NewRequest(http.MethodGet, "/proxy", nil)
r = r.WithContext(WithTraceID(r.Context(), "err-trace"))
w := httptest.NewRecorder()
Expand All @@ -198,7 +214,7 @@ func TestRequestLoggerMiddleware_NoTraceID_LogsEmpty(t *testing.T) {
w.WriteHeader(http.StatusOK)
})

handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner)
handler := RequestLoggerMiddleware(logger, "X-Connect", inner)
r := httptest.NewRequest(http.MethodGet, "/proxy", nil)
// Deliberately no trace ID in context
w := httptest.NewRecorder()
Expand All @@ -220,7 +236,7 @@ func TestRequestLoggerMiddleware_LogsOnPanic(t *testing.T) {
})

// Wrap with panic recovery INSIDE request logger, so logger still fires
handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", panicRecoveryForTest(panicky))
handler := RequestLoggerMiddleware(logger, "X-Connect", panicRecoveryForTest(panicky))
r := httptest.NewRequest(http.MethodGet, "/proxy", nil)
r = r.WithContext(WithTraceID(r.Context(), "panic-trace"))
w := httptest.NewRecorder()
Expand All @@ -244,7 +260,7 @@ func TestRequestLoggerMiddleware_ClientIPFromXForwardedFor(t *testing.T) {
w.WriteHeader(http.StatusOK)
})

handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner)
handler := RequestLoggerMiddleware(logger, "X-Connect", inner)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("X-Forwarded-For", "203.0.113.50")
w := httptest.NewRecorder()
Expand All @@ -269,7 +285,7 @@ func TestRequestLoggerMiddleware_TraceIDFromOuterMiddleware(t *testing.T) {
})

// Production order: TraceID (outermost) → Logger → handler
handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner)
handler := RequestLoggerMiddleware(logger, "X-Connect", inner)
handler = TraceIDMiddleware("X-Trace-ID", handler)

r := httptest.NewRequest(http.MethodGet, "/test", nil)
Expand All @@ -295,7 +311,7 @@ func TestRequestLoggerMiddleware_TraceIDGenerated(t *testing.T) {
})

// Production order: TraceID (outermost) → Logger → handler
handler := RequestLoggerMiddleware(logger, "X-Connect-Vendor-ID", inner)
handler := RequestLoggerMiddleware(logger, "X-Connect", inner)
handler = TraceIDMiddleware("X-Trace-ID", handler)

r := httptest.NewRequest(http.MethodGet, "/test", nil)
Expand All @@ -310,6 +326,86 @@ func TestRequestLoggerMiddleware_TraceIDGenerated(t *testing.T) {
}
}

func TestRequestLoggerMiddleware_LogsTargetHost_ExtractsHostOnly(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug}))

inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

handler := RequestLoggerMiddleware(logger, "X-Connect", inner)
r := httptest.NewRequest(http.MethodGet, "/proxy", nil)
// URL with query string and path — only the host should appear in the log
r.Header.Set("X-Connect-Target-URL", "https://api.vendor.com/v1/users?api_key=secret&token=abc")
w := httptest.NewRecorder()

handler.ServeHTTP(w, r)

entry := parseLogEntry(t, buf.Bytes())
if entry.TargetHost != "api.vendor.com" {
t.Errorf("target_host = %q, want %q (only host, no path/query)", entry.TargetHost, "api.vendor.com")
}
// Verify query params did not leak into any logged field
logOutput := buf.String()
if strings.Contains(logOutput, "api_key") || strings.Contains(logOutput, "secret") {
t.Errorf("log should not contain query params, got: %s", logOutput)
}
}

func TestExtractHost(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "full URL with path",
input: "https://api.vendor.com/v1/users",
want: "api.vendor.com",
},
{
name: "URL with port",
input: "https://api.vendor.com:8443/v1",
want: "api.vendor.com:8443",
},
{
name: "URL with query string",
input: "https://api.vendor.com/v1?key=secret",
want: "api.vendor.com",
},
{
name: "URL without path",
input: "https://api.vendor.com",
want: "api.vendor.com",
},
{
name: "empty URL returns empty",
input: "",
want: "",
},
{
name: "invalid URL returns empty",
input: "://invalid",
want: "",
},
{
name: "path-only URL returns empty",
input: "/just/a/path",
want: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractHost(tt.input)
if got != tt.want {
t.Errorf("extractHost(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}

// panicRecoveryForTest wraps a handler with basic panic recovery for testing.
func panicRecoveryForTest(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading
Loading