diff --git a/cmd/cmd.go b/cmd/cmd.go index 1733c0b..9766412 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "net/http" + "os" + "path/filepath" "github.com/fgiudici/headertrace/api" hdrs "github.com/fgiudici/headertrace/pkg/headers" @@ -15,29 +17,41 @@ var ( port string host string headers []string + dropHeaders []string sentHeaders bool + privMode bool printVersion bool ) func init() { + pflag.Usage = func() { + fmt.Fprintf(os.Stderr, "HeaderTrace %s - A simple HTTP server that echoes back received HTTP headers\n\n", getVersion()) + fmt.Fprintf(os.Stderr, "Usage: %s [flags]\n\n", filepath.Base(os.Args[0])) + pflag.PrintDefaults() + } + pflag.StringVarP(&host, "address", "a", "0.0.0.0", "IP address (or domain) to bind to") pflag.StringVarP(&port, "port", "p", "8080", "TCP port to bind to") - pflag.StringSliceVarP(&headers, "header", "H", []string{}, "Custom HTTP headers to add to the HTTP responses (key:value format)") - pflag.BoolVarP(&sentHeaders, "sent", "s", false, "Include the original HTTP headers added to the response in the body") + pflag.StringSliceVarP(&headers, "header", "H", []string{}, "Custom HTTP headers to add to responses (key1:value1,key2:value2)") + pflag.StringSliceVarP(&dropHeaders, "drop-header", "D", []string{}, "HTTP headers to redact from request headers echoed in the response body (key1,key2)") + pflag.BoolVarP(&privMode, "privacy", "P", false, "Drop X-Forwarded and Cloudflare headers from request headers echoed in the response body") + pflag.BoolVarP(&sentHeaders, "sent", "s", false, "Dump the HTTP headers added in the response in the response body") pflag.BoolVarP(&printVersion, "version", "v", false, "Print version and exit") } type server struct { headers map[string]string + dropHeaders []string + privMode bool sentHeaders bool } // Get implements api.ServerInterface func (s *server) Get(w http.ResponseWriter, r *http.Request) { - logging.Infof("Received request: %s", hdrs.RemoteHostInfo(r)) + logging.Infof("Received request: %s", hdrs.GetRemoteHostInfo(r)) // Convert headers to map - headers := hdrs.ToMap(r.Header) + headers := hdrs.ToMap(r.Header, s.dropHeaders, s.privMode) var xHeadersPtr *map[string]string protocol := r.Proto @@ -53,7 +67,7 @@ func (s *server) Get(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) if s.sentHeaders { - xHeaders := hdrs.ToMap(w.Header()) + xHeaders := hdrs.ToMap(w.Header(), nil, false) xHeadersPtr = &xHeaders } @@ -93,7 +107,10 @@ func Execute() error { } // Create server instance - srv := &server{headers: customHeaders, sentHeaders: sentHeaders} + srv := &server{headers: customHeaders, + dropHeaders: dropHeaders, + privMode: privMode, + sentHeaders: sentHeaders} // Create handler from the generated code handler := api.Handler(srv) diff --git a/pkg/headers/headers.go b/pkg/headers/headers.go index 0e45ef1..447fdfe 100644 --- a/pkg/headers/headers.go +++ b/pkg/headers/headers.go @@ -3,7 +3,10 @@ package headers import ( "fmt" "net/http" + "slices" "strings" + + "github.com/fgiudici/headertrace/pkg/logging" ) // Slice2Map takes a slice of header strings in "key:value" format and returns a map. @@ -15,34 +18,88 @@ func SliceToMap(headerStrings []string) (map[string]string, error) { if len(parts) != 2 { return nil, fmt.Errorf("invalid header format '%s', expected 'key:value'", h) } + parts[0] = strings.TrimSpace(parts[0]) if parts[0] == "" { return nil, fmt.Errorf("header key cannot be empty in '%s'", h) } - headers[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + headers[parts[0]] = strings.TrimSpace(parts[1]) } return headers, nil } // ToMap converts an http.Header to a "key:value" map. -func ToMap(headers http.Header) map[string]string { +// It takes a list of headers to drop and a privacy mode flag to exclude headers that may reveal +// sensitive information of the internal network. Note that enabling debug logging will log all dropped headers. +func ToMap(headers http.Header, dropHeaders []string, privMode bool) map[string]string { headerMap := make(map[string]string) + normalizedDropHeaders := sliceToLower(dropHeaders) + for key, values := range headers { + lowerKey := strings.ToLower(key) + if slices.Contains(normalizedDropHeaders, lowerKey) { + logging.Debugf("Dropping header '%s':'%s'", key, strings.Join(values, ",")) + continue + } + if privMode { + if isCloudflareHeader(lowerKey) || isXForwardedHeader(lowerKey) { + logging.Debugf("Dropping header '%s':'%s' (privacy mode)", key, strings.Join(values, ",")) + continue + } + } headerMap[key] = strings.Join(values, ",") } return headerMap } -func RemoteHostInfo(r *http.Request) string { +func sliceToLower(headers []string) []string { + lower := make([]string, len(headers)) + for i, h := range headers { + lower[i] = strings.ToLower(h) + } + return lower +} + +// isCloudflareHeader checks if a header is a Cloudflare-specific header that should be dropped in privacy mode. +// NOTE: it expects headers to be already normalized to lowercase. +func isCloudflareHeader(header string) bool { + return strings.HasPrefix(header, "cf-") +} + +// isXForwardedHeader checks if a header is an X-Forwarded or X-Real-IP header that should be dropped in privacy mode. +// NOTE: it expects headers to be already normalized to lowercase. +func isXForwardedHeader(header string) bool { + return strings.HasPrefix(header, "x-forwarded-") || header == "x-real-ip" +} + +// GetRemoteHostInfo extracts the remote host information from the request, inspecting common proxy headers like CF-Connecting-IP, X-Real-IP, and X-Forwarded-For. +// It returns a formatted string with the remote address and user agent. +func GetRemoteHostInfo(r *http.Request) string { + // Example of received headers: + // "Accept": "*/*", + // "Accept-Encoding": "gzip", + // "Cdn-Loop": "cloudflare; loops=1", + // "Cf-Connecting-Ip": "1.2.3.4", + // "Cf-Ipcountry": "IT", + // "Cf-Ray": "9cbdc3515d22baf3-MXP", + // "Cf-Visitor": "{\"scheme\":\"http\"}", + // "User-Agent": "curl/7.88.1", + // "X-Forwarded-For": "10.22.0.0", + // "X-Forwarded-Host": "headers.example.com", + // "X-Forwarded-Port": "80", + // "X-Forwarded-Proto": "http", + // "X-Forwarded-Server": "traefik-73f98ac65-z1drx", + // "X-Real-Ip": "10.22.0.0" + remoteAddr := r.RemoteAddr userAgent := r.Header.Get("User-Agent") // Proxied through Cloudflare? if remote := r.Header.Get("CF-Connecting-IP"); remote != "" { - remoteAddr = fmt.Sprintf("%s (%s)", remote, r.Header.Get("Cf-Ipcountry")) + remoteAddr = fmt.Sprintf("%s(%s) [%s]", remote, r.Header.Get("Cf-Ipcountry"), remoteAddr) } else if remote := r.Header.Get("X-Real-Ip"); remote != "" { - remoteAddr = remote + remoteAddr = fmt.Sprintf("%s [%s]", remote, remoteAddr) } else if remote := r.Header.Get("X-Forwarded-For"); remote != "" { - remoteAddr = remote + remoteAddr = fmt.Sprintf("%s [%s]", remote, remoteAddr) } return fmt.Sprintf("%s %q - %s %s %q", remoteAddr, userAgent, r.Method, r.Proto, r.URL.String()) diff --git a/pkg/headers/headers_test.go b/pkg/headers/headers_test.go index 261ea8d..fed1498 100644 --- a/pkg/headers/headers_test.go +++ b/pkg/headers/headers_test.go @@ -1,7 +1,10 @@ package headers import ( + "net/http" + "net/http/httptest" "reflect" + "strings" "testing" ) @@ -57,6 +60,34 @@ func TestSliceToMap(t *testing.T) { input: []string{"Valid:header", "InvalidHeader"}, wantErr: true, }, + { + name: "empty input slice", + input: []string{}, + want: map[string]string{}, + }, + { + name: "whitespace only key and value", + input: []string{" : "}, + wantErr: true, + }, + { + name: "duplicate header keys", + input: []string{"X-Custom:first", "X-Custom:second"}, + want: map[string]string{"X-Custom": "second"}, + }, + { + name: "case sensitive keys", + input: []string{"x-custom:value1", "X-Custom:value2"}, + want: map[string]string{ + "x-custom": "value1", + "X-Custom": "value2", + }, + }, + { + name: "special characters in value", + input: []string{"X-Custom:!@#$%^&*()"}, + want: map[string]string{"X-Custom": "!@#$%^&*()"}, + }, } for _, tt := range tests { @@ -77,3 +108,204 @@ func TestSliceToMap(t *testing.T) { }) } } + +func TestToMap(t *testing.T) { + tests := []struct { + name string + headers http.Header + dropHeaders []string + privMode bool + want map[string]string + }{ + { + name: "basic headers", + headers: http.Header{ + "X-Custom": {"value"}, + }, + dropHeaders: []string{}, + privMode: false, + want: map[string]string{ + "X-Custom": "value", + }, + }, + { + name: "drop headers", + headers: http.Header{ + "X-Custom": {"value"}, + }, + dropHeaders: []string{"x-custom"}, + privMode: false, + want: map[string]string{}, + }, + { + name: "privMode drops Cloudflare headers", + headers: http.Header{ + "CF-Ray": {"9cbdc3515d22baf3-MXP"}, + "Cf-Visitor": {"{\"scheme\":\"https\"}"}, + "cf-Connecting-Ip": {"10.22.0.0"}, + "X-Custom": {"value"}, + }, + dropHeaders: []string{}, + privMode: true, + want: map[string]string{ + "X-Custom": "value", + }, + }, + { + name: "privMode drops X-Forwarded headers", + headers: http.Header{ + "X-Forwarded-For": {"10.22.0.0"}, + "X-forwarded-Host": {"example.com"}, + "x-forwarded-proto": {"https"}, + "X-Custom": {"value"}, + }, + dropHeaders: []string{}, + privMode: true, + want: map[string]string{ + "X-Custom": "value", + }, + }, + { + name: "privMode drops X-Real-IP header", + headers: http.Header{ + "X-Real-Ip": {"10.22.0.0"}, + "x-Real-Ip": {"10.22.0.0"}, + "X-real-ip": {"10.22.0.0"}, + "X-Custom": {"value"}, + }, + dropHeaders: []string{}, + privMode: true, + want: map[string]string{ + "X-Custom": "value", + }, + }, + { + name: "privMode false keeps all headers", + headers: http.Header{ + "CF-Ray": {"9cbdc3515d22baf3-MXP"}, + "X-Forwarded-For": {"10.22.0.0"}, + "X-Real-Ip": {"10.22.0.0"}, + "X-Custom": {"value"}, + }, + dropHeaders: []string{}, + privMode: false, + want: map[string]string{ + "CF-Ray": "9cbdc3515d22baf3-MXP", + "X-Forwarded-For": "10.22.0.0", + "X-Real-Ip": "10.22.0.0", + "X-Custom": "value", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ToMap(tt.headers, tt.dropHeaders, tt.privMode) + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("ToMap() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetRemoteHostInfo(t *testing.T) { + tests := []struct { + name string + remoteAddr string + headers http.Header + method string + urlString string + expectedIP string + }{ + { + name: "uses CF-Connecting-IP with Cf-Ipcountry", + remoteAddr: "127.0.0.1:5000", + headers: http.Header{ + "CF-Connecting-IP": {"1.2.3.4"}, + "Cf-Ipcountry": {"US"}, + "X-Real-Ip": {"5.6.7.8"}, + "X-Forwarded-For": {"9.10.11.12"}, + }, + method: "GET", + urlString: "http://example.com/", + expectedIP: "1.2.3.4(US)", + }, + { + name: "uses X-Real-IP when CF-Connecting-IP not available", + remoteAddr: "127.0.0.1:5000", + headers: http.Header{ + "X-Real-Ip": {"5.6.7.8"}, + "X-Forwarded-For": {"9.10.11.12"}, + }, + method: "GET", + urlString: "http://example.com/", + expectedIP: "5.6.7.8", + }, + { + name: "uses X-Forwarded-For when CF-Connecting-IP and X-Real-IP not available", + remoteAddr: "127.0.0.1:5000", + headers: http.Header{ + "X-Forwarded-For": {"9.10.11.12"}, + }, + method: "GET", + urlString: "http://example.com/", + expectedIP: "9.10.11.12", + }, + { + name: "uses r.RemoteAddr when no proxy headers available", + remoteAddr: "192.168.1.1:8080", + headers: http.Header{}, + method: "GET", + urlString: "http://example.com/", + expectedIP: "192.168.1.1:8080", + }, + { + name: "prefers cf-Connecting-IP over X-Real-IP", + remoteAddr: "127.0.0.1:5000", + headers: http.Header{ + "cf-Connecting-IP": {"1.2.3.4"}, + "cf-Ipcountry": {"IT"}, + "X-Real-Ip": {"5.6.7.8"}, + "X-Forwarded-For": {"9.10.11.12"}, + }, + method: "GET", + urlString: "http://example.com/", + expectedIP: "1.2.3.4(IT)", + }, + { + name: "prefers x-Real-ip over X-Forwarded-For when CF-Connecting-IP missing", + remoteAddr: "127.0.0.1:5000", + headers: http.Header{ + "x-Real-ip": {"5.6.7.8"}, + "X-Forwarded-For": {"9.10.11.12"}, + }, + method: "GET", + urlString: "http://example.com/", + expectedIP: "5.6.7.8", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.urlString, nil) + req.RemoteAddr = tt.remoteAddr + for key, values := range tt.headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + + got := GetRemoteHostInfo(req) + + // Verify that the expected IP is contained in the result + if !strings.Contains(got, tt.expectedIP) { + t.Fatalf("GetRemoteHostInfo() = %q, expected to contain IP %q", got, tt.expectedIP) + } + + // Verify the format includes method and proto + if !strings.Contains(got, tt.method) { + t.Fatalf("GetRemoteHostInfo() = %q, expected to contain method %q", got, tt.method) + } + }) + } +}