From 24d70507e47ae58c6728525c4354c7f9eeea102f Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Tue, 14 Apr 2026 13:48:46 +1000 Subject: [PATCH] feat: optional propagation of headers into logs This is useful for basic tracing. Co-authored-by: Claude Code --- cmd/cachewd/main.go | 21 ++++-- internal/httputil/logging.go | 17 ----- internal/logging/logging.go | 33 ++++++++- internal/logging/middleware_test.go | 108 ++++++++++++++++++++++++++++ 4 files changed, 155 insertions(+), 24 deletions(-) delete mode 100644 internal/httputil/logging.go create mode 100644 internal/logging/middleware_test.go diff --git a/cmd/cachewd/main.go b/cmd/cachewd/main.go index de8d6025..6f20a800 100644 --- a/cmd/cachewd/main.go +++ b/cmd/cachewd/main.go @@ -23,7 +23,6 @@ import ( "github.com/block/cachew/internal/config" "github.com/block/cachew/internal/gitclone" "github.com/block/cachew/internal/githubapp" - "github.com/block/cachew/internal/httputil" "github.com/block/cachew/internal/jobscheduler" "github.com/block/cachew/internal/logging" "github.com/block/cachew/internal/metadatadb" @@ -108,7 +107,14 @@ func main() { logger.InfoContext(ctx, "Starting cachewd", "bind", globalConfig.Bind) - server, err := newServer(ctx, mux, globalConfig.Bind, globalConfig.MetricsConfig, globalConfig.OPAConfig) + server, err := newServer( + ctx, + mux, + globalConfig.Bind, + globalConfig.MetricsConfig, + globalConfig.OPAConfig, + globalConfig.LoggingConfig, + ) fatalIfError(ctx, logger, err, "Failed to create server") err = server.ListenAndServe() @@ -220,7 +226,14 @@ func extractPathPrefix(path string) string { return prefix } -func newServer(ctx context.Context, muxHandler http.Handler, bind string, metricsConfig metrics.Config, opaConfig opa.Config) (*http.Server, error) { +func newServer( + ctx context.Context, + muxHandler http.Handler, + bind string, + metricsConfig metrics.Config, + opaConfig opa.Config, + logConfig logging.Config, +) (*http.Server, error) { var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { labeler, _ := otelhttp.LabelerFromContext(r.Context()) labeler.Add(attribute.String("cachew.http.path.prefix", extractPathPrefix(r.URL.Path))) @@ -238,7 +251,7 @@ func newServer(ctx context.Context, muxHandler http.Handler, bind string, metric otelhttp.WithTracerProvider(otel.GetTracerProvider()), )(handler) - handler = httputil.LoggingMiddleware(handler) + handler = logging.Middleware(handler, logConfig) logger := logging.FromContext(ctx) return &http.Server{ diff --git a/internal/httputil/logging.go b/internal/httputil/logging.go deleted file mode 100644 index aad2d1e1..00000000 --- a/internal/httputil/logging.go +++ /dev/null @@ -1,17 +0,0 @@ -package httputil - -import ( - "net/http" - - "github.com/block/cachew/internal/logging" -) - -func LoggingMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Propagate attributes tot the handlers. - logger := logging.FromContext(r.Context()).With("method", r.Method, "uri", r.RequestURI) - r = r.WithContext(logging.ContextWithLogger(r.Context(), logger)) - logger.Debug("Request received") - next.ServeHTTP(w, r) - }) -} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 20a17968..a854b1f0 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -4,15 +4,42 @@ package logging import ( "context" "log/slog" + "net/http" "os" + "time" "github.com/lmittmann/tint" ) type Config struct { - JSON bool `hcl:"json,optional" help:"Enable JSON logging."` - Level slog.Level `hcl:"level" help:"Set the logging level." default:"info"` - Remap map[string]string `hcl:"remap,optional" help:"Remap field names from old to new (e.g., msg=message, time=timestamp)."` + JSON bool `hcl:"json,optional" help:"Enable JSON logging."` + Level slog.Level `hcl:"level" help:"Set the logging level." default:"info"` + Remap map[string]string `hcl:"remap,optional" help:"Remap field names from old to new (e.g., msg=message, time=timestamp)."` + Headers map[string]string `hcl:"headers,optional" help:"Propagate these inbound request headers to the given log attribute."` +} + +// Middleware returns an HTTP middleware that logs incoming requests and attaches +// any configured headers as log attributes. +func Middleware(next http.Handler, config Config) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // Propagate attributes tot the handlers. + logger := FromContext(ctx).With("method", r.Method, "uri", r.RequestURI) + start := time.Now() + logger.Debug("Request received") + var attrs []any + for header, attr := range config.Headers { + if h := r.Header.Get(header); h != "" { + attrs = append(attrs, slog.String(attr, h)) + } + } + if len(attrs) > 0 { + logger = logger.With(attrs...) + r = r.WithContext(ContextWithLogger(ctx, logger)) + } + next.ServeHTTP(w, r) + logger.Debug("Request complete", "elapsed", time.Since(start)) + }) } var levelVar = &slog.LevelVar{} //nolint:gochecknoglobals diff --git a/internal/logging/middleware_test.go b/internal/logging/middleware_test.go new file mode 100644 index 00000000..e1f13a12 --- /dev/null +++ b/internal/logging/middleware_test.go @@ -0,0 +1,108 @@ +package logging //nolint:testpackage + +import ( + "bytes" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/alecthomas/assert/v2" +) + +func TestMiddleware(t *testing.T) { + tests := []struct { + name string + config Config + headers map[string]string + wantAttrs map[string]string + wantAbsent []string + }{ + { + name: "NoHeadersConfigured", + config: Config{}, + }, + { + name: "HeaderPresent", + config: Config{Headers: map[string]string{"X-Request-ID": "request_id"}}, + headers: map[string]string{ + "X-Request-ID": "abc-123", + }, + wantAttrs: map[string]string{"request_id": "abc-123"}, + }, + { + name: "HeaderMissing", + config: Config{Headers: map[string]string{"X-Request-ID": "request_id"}}, + wantAbsent: []string{"request_id"}, + }, + { + name: "MixedPresentAndMissing", + config: Config{Headers: map[string]string{ + "X-Request-ID": "request_id", + "X-Trace-ID": "trace_id", + }}, + headers: map[string]string{ + "X-Request-ID": "abc-123", + }, + wantAttrs: map[string]string{"request_id": "abc-123"}, + wantAbsent: []string{"trace_id"}, + }, + { + name: "MultipleHeadersPresent", + config: Config{Headers: map[string]string{ + "X-Request-ID": "request_id", + "X-Trace-ID": "trace_id", + }}, + headers: map[string]string{ + "X-Request-ID": "abc-123", + "X-Trace-ID": "def-456", + }, + wantAttrs: map[string]string{ + "request_id": "abc-123", + "trace_id": "def-456", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + return slog.Attr{} + } + return a + }, + })) + + inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + FromContext(r.Context()).Info("test") + }) + + handler := Middleware(inner, tt.config) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req = req.WithContext(ContextWithLogger(req.Context(), logger)) + for k, v := range tt.headers { + req.Header.Set(k, v) + } + + handler.ServeHTTP(httptest.NewRecorder(), req) + + var entry map[string]any + assert.NoError(t, json.Unmarshal(buf.Bytes(), &entry)) + + for attr, want := range tt.wantAttrs { + got, ok := entry[attr].(string) + assert.True(t, ok, "expected attribute %q to be a string", attr) + assert.Equal(t, want, got) + } + for _, attr := range tt.wantAbsent { + _, present := entry[attr] + assert.False(t, present, "expected attribute %q to be absent", attr) + } + }) + } +}