diff --git a/cmd/cmd.go b/cmd/cmd.go index 9766412..23143fa 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -21,6 +21,7 @@ var ( sentHeaders bool privMode bool printVersion bool + logLevel string ) func init() { @@ -37,6 +38,7 @@ func init() { pflag.BoolVarP(&privMode, "privacy", "P", false, "Drop X-Forwarded and Cloudflare headers from request headers echoed in the response body") pflag.BoolVarP(&sentHeaders, "sent", "s", false, "Dump the HTTP headers added in the response in the response body") pflag.BoolVarP(&printVersion, "version", "v", false, "Print version and exit") + pflag.StringVarP(&logLevel, "log-level", "l", "", "Logging level: TRACE, DEBUG, INFO, WARN, ERROR (overrides the LOG_LEVEL env variable)") } type server struct { @@ -67,6 +69,7 @@ func (s *server) Get(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) if s.sentHeaders { + logging.Tracef("Dumping sent headers to response body") xHeaders := hdrs.ToMap(w.Header(), nil, false) xHeadersPtr = &xHeaders } @@ -100,11 +103,32 @@ func Execute() error { fmt.Println(getVersion()) return nil } + + // Init logging level + if logLevel != "" { + os.Setenv("LOG_LEVEL", logLevel) + } + if err := logging.Init(logLevel); err != nil { + logging.Fatalf("Failed to initialize logging: %v", err) + } + + logging.Infof("Starting HeaderTrace version %s", getVersion()) + // Parse custom headers customHeaders, err := hdrs.SliceToMap(headers) if err != nil { - logging.Fatalf("%v", err) + logging.Fatalf("Custom headers: %v", err) } + if len(customHeaders) > 0 { + logging.Debugf("Custom headers to add in responses: %v", customHeaders) + } + + if len(dropHeaders) > 0 { + logging.Debugf("Headers to drop from echoed request headers: %v", dropHeaders) + } + + logging.Debugf("Privacy mode: %v", privMode) + logging.Debugf("Dump sent headers: %v", sentHeaders) // Create server instance srv := &server{headers: customHeaders, diff --git a/main.go b/main.go index 8379e94..ba123e8 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,6 @@ import ( ) func main() { - logging.Init() if err := cmd.Execute(); err != nil { logging.Fatalf("%v", err) } diff --git a/pkg/headers/headers.go b/pkg/headers/headers.go index 447fdfe..5f355a3 100644 --- a/pkg/headers/headers.go +++ b/pkg/headers/headers.go @@ -37,16 +37,17 @@ func ToMap(headers http.Header, dropHeaders []string, privMode bool) map[string] for key, values := range headers { lowerKey := strings.ToLower(key) if slices.Contains(normalizedDropHeaders, lowerKey) { - logging.Debugf("Dropping header '%s':'%s'", key, strings.Join(values, ",")) + logging.Debugf("Redact header '%s':'%s'", key, strings.Join(values, ",")) continue } if privMode { if isCloudflareHeader(lowerKey) || isXForwardedHeader(lowerKey) { - logging.Debugf("Dropping header '%s':'%s' (privacy mode)", key, strings.Join(values, ",")) + logging.Debugf("Redact header '%s':'%s' (privacy mode)", key, strings.Join(values, ",")) continue } } headerMap[key] = strings.Join(values, ",") + logging.Tracef("Dump header '%s':'%s'", key, headerMap[key]) } return headerMap } diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go index 132d573..09fe65f 100644 --- a/pkg/logging/logging.go +++ b/pkg/logging/logging.go @@ -1,6 +1,7 @@ package logging import ( + "fmt" "log" "os" "strings" @@ -9,7 +10,8 @@ import ( type level int const ( - DEBUG level = iota + TRACE level = iota + DEBUG INFO WARN ERROR @@ -18,21 +20,39 @@ const ( var lvl = INFO // Init configures the logger. It reads LOG_LEVEL from the environment -// (one of: DEBUG, INFO, WARN, ERROR) and sets a simple prefix. -func Init() { - if v := os.Getenv("LOG_LEVEL"); v != "" { - switch strings.ToUpper(v) { - case "DEBUG": - lvl = DEBUG - case "INFO": - lvl = INFO - case "WARN": - lvl = WARN - case "ERROR": - lvl = ERROR - } +// (one of: TRACE, DEBUG, INFO, WARN, ERROR) and sets a simple prefix. +func Init(levelStr string) error { + if levelStr == "" { + levelStr = os.Getenv("LOG_LEVEL") } + if levelStr == "" { + levelStr = "INFO" + } + + switch strings.ToUpper(levelStr) { + case "TRACE": + lvl = TRACE + case "DEBUG": + lvl = DEBUG + case "INFO": + lvl = INFO + case "WARN": + lvl = WARN + case "ERROR": + lvl = ERROR + default: + return fmt.Errorf("invalid log level: %s", levelStr) + } + log.SetFlags(log.LstdFlags) + Tracef("loglevel set to %s", strings.ToUpper(levelStr)) + return nil +} + +func Tracef(format string, v ...interface{}) { + if lvl <= TRACE { + log.Printf("TRACE: "+format, v...) + } } func Debugf(format string, v ...interface{}) {