Skip to content

Commit 6a71fdb

Browse files
authored
Consolidate bearer token extraction into utility (#2410)
Extract bearer token extraction logic from three different locations into a single, utility function to avoid code duplication.
1 parent 22a02cd commit 6a71fdb

File tree

4 files changed

+146
-25
lines changed

4 files changed

+146
-25
lines changed

pkg/auth/token.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -867,24 +867,14 @@ func (v *TokenValidator) buildWWWAuthenticate(includeError bool, errDescription
867867
// Middleware creates an HTTP middleware that validates JWT tokens.
868868
func (v *TokenValidator) Middleware(next http.Handler) http.Handler {
869869
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
870-
// Get the token from the Authorization header
871-
authHeader := r.Header.Get("Authorization")
872-
if authHeader == "" {
873-
w.Header().Set("WWW-Authenticate", v.buildWWWAuthenticate(false, ""))
874-
http.Error(w, "Authorization header required", http.StatusUnauthorized)
875-
return
876-
}
877-
878-
// Check if the Authorization header has the Bearer prefix
879-
if !strings.HasPrefix(authHeader, "Bearer ") {
870+
// Extract the bearer token from the Authorization header
871+
tokenString, err := ExtractBearerToken(r)
872+
if err != nil {
880873
w.Header().Set("WWW-Authenticate", v.buildWWWAuthenticate(false, ""))
881-
http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized)
874+
http.Error(w, err.Error(), http.StatusUnauthorized)
882875
return
883876
}
884877

885-
// Extract the token
886-
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
887-
888878
// Validate the token
889879
claims, err := v.ValidateToken(r.Context(), tokenString)
890880
if err != nil {

pkg/auth/tokenexchange/middleware.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"fmt"
77
"net/http"
88
"os"
9-
"strings"
109

1110
"github.com/golang-jwt/jwt/v5"
1211
"golang.org/x/oauth2"
@@ -313,16 +312,9 @@ func createTokenExchangeMiddleware(
313312
tokenProvider = subjectTokenProvider
314313
} else {
315314
// otherwise, extract token from incoming request's Authorization header
316-
authHeader := r.Header.Get("Authorization")
317-
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
318-
logger.Debug("No valid Bearer token found, proceeding without token exchange")
319-
next.ServeHTTP(w, r)
320-
return
321-
}
322-
323-
subjectToken := strings.TrimPrefix(authHeader, "Bearer ")
324-
if subjectToken == "" {
325-
logger.Debug("Empty Bearer token, proceeding without token exchange")
315+
subjectToken, err := auth.ExtractBearerToken(r)
316+
if err != nil {
317+
logger.Debugf("No valid Bearer token found (%v), proceeding without token exchange", err)
326318
next.ServeHTTP(w, r)
327319
return
328320
}

pkg/auth/utils.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package auth
33

44
import (
55
"context"
6+
"errors"
67
"net/http"
78
"os/user"
89
"strings"
@@ -12,6 +13,52 @@ import (
1213
"github.com/stacklok/toolhive/pkg/logger"
1314
)
1415

16+
// bearerTokenType defines the expected token type for Bearer authentication.
17+
const bearerTokenType = "Bearer"
18+
19+
// Common Bearer token extraction errors
20+
var (
21+
ErrAuthHeaderMissing = errors.New("authorization header required")
22+
ErrInvalidAuthHeaderFormat = errors.New("invalid authorization header format, expected 'Bearer <token>'")
23+
ErrEmptyBearerToken = errors.New("empty token in authorization header")
24+
)
25+
26+
// ExtractBearerToken extracts and validates a Bearer token from the Authorization header.
27+
// It performs the following validations:
28+
// 1. Verifies the Authorization header is present
29+
// 2. Checks for the "Bearer " prefix (case-sensitive per RFC 6750)
30+
// 3. Ensures the token is not empty after removing the prefix
31+
//
32+
// The function returns the token string (without "Bearer " prefix) and any validation error.
33+
// Callers are responsible for further token validation (JWT parsing, introspection, etc.)
34+
// and for converting errors to appropriate HTTP responses.
35+
//
36+
// This function implements RFC 6750 Section 2.1 (Bearer Token Authorization Header).
37+
// See: https://datatracker.ietf.org/doc/html/rfc6750#section-2.1
38+
func ExtractBearerToken(r *http.Request) (string, error) {
39+
// Get the Authorization header
40+
authHeader := r.Header.Get("Authorization")
41+
if authHeader == "" {
42+
return "", ErrAuthHeaderMissing
43+
}
44+
45+
// Check for the Bearer prefix (case-sensitive per RFC 6750)
46+
bearerPrefix := bearerTokenType + " "
47+
if !strings.HasPrefix(authHeader, bearerPrefix) {
48+
return "", ErrInvalidAuthHeaderFormat
49+
}
50+
51+
// Extract the token
52+
tokenString := strings.TrimPrefix(authHeader, bearerPrefix)
53+
54+
// Check for empty token (handles "Bearer " with no token or only whitespace)
55+
if strings.TrimSpace(tokenString) == "" {
56+
return "", ErrEmptyBearerToken
57+
}
58+
59+
return tokenString, nil
60+
}
61+
1562
// GetClaimsFromContext retrieves the claims from the request context.
1663
// This is a helper function that can be used by authorization policies
1764
// to access the claims regardless of which middleware was used (JWT, anonymous, or local).

pkg/auth/utils_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package auth
22

33
import (
44
"context"
5+
"errors"
56
"net/http"
67
"net/http/httptest"
78
"testing"
@@ -13,6 +14,97 @@ import (
1314
"github.com/stacklok/toolhive/pkg/logger"
1415
)
1516

17+
func TestExtractBearerToken(t *testing.T) {
18+
t.Parallel()
19+
20+
testCases := []struct {
21+
name string
22+
authHeader string
23+
expectedToken string
24+
expectedError error
25+
}{
26+
{
27+
name: "valid_bearer_token",
28+
authHeader: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
29+
expectedToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
30+
expectedError: nil,
31+
},
32+
{
33+
name: "missing_authorization_header",
34+
authHeader: "",
35+
expectedToken: "",
36+
expectedError: ErrAuthHeaderMissing,
37+
},
38+
{
39+
name: "invalid_format_no_bearer_prefix",
40+
authHeader: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
41+
expectedToken: "",
42+
expectedError: ErrInvalidAuthHeaderFormat,
43+
},
44+
{
45+
name: "lowercase_bearer",
46+
authHeader: "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
47+
expectedToken: "",
48+
expectedError: ErrInvalidAuthHeaderFormat,
49+
},
50+
{
51+
name: "empty_token_after_prefix",
52+
authHeader: "Bearer ",
53+
expectedToken: "",
54+
expectedError: ErrEmptyBearerToken,
55+
},
56+
{
57+
name: "empty_token_with_trailing_spaces",
58+
authHeader: "Bearer ",
59+
expectedToken: "",
60+
expectedError: ErrEmptyBearerToken,
61+
},
62+
{
63+
name: "token_with_spaces_valid_per_rfc",
64+
authHeader: "Bearer token with spaces",
65+
expectedToken: "token with spaces",
66+
expectedError: nil,
67+
},
68+
{
69+
name: "basic_auth_instead_of_bearer",
70+
authHeader: "Basic dXNlcjpwYXNz",
71+
expectedToken: "",
72+
expectedError: ErrInvalidAuthHeaderFormat,
73+
},
74+
{
75+
name: "token_with_special_characters",
76+
authHeader: "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0In0.abc-def_ghi",
77+
expectedToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0In0.abc-def_ghi",
78+
expectedError: nil,
79+
},
80+
}
81+
82+
for _, tc := range testCases {
83+
t.Run(tc.name, func(t *testing.T) {
84+
t.Parallel()
85+
86+
// Create a test request with the authorization header
87+
req := httptest.NewRequest("GET", "/test", nil)
88+
if tc.authHeader != "" {
89+
req.Header.Set("Authorization", tc.authHeader)
90+
}
91+
92+
// Extract the bearer token
93+
token, err := ExtractBearerToken(req)
94+
95+
// Check the error
96+
if tc.expectedError != nil {
97+
require.Error(t, err)
98+
assert.True(t, errors.Is(err, tc.expectedError), "Expected error %v, got %v", tc.expectedError, err)
99+
assert.Empty(t, token)
100+
} else {
101+
require.NoError(t, err)
102+
assert.Equal(t, tc.expectedToken, token)
103+
}
104+
})
105+
}
106+
}
107+
16108
func TestGetClaimsFromContext(t *testing.T) {
17109
t.Parallel()
18110
// Test with claims in context

0 commit comments

Comments
 (0)