Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions pkg/auth/anonymous_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,32 @@ import (

func TestAnonymousMiddleware(t *testing.T) {
t.Parallel()
// Create a test handler that checks for claims in the context
// Create a test handler that checks for identity in the context
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := GetClaimsFromContext(r.Context())
require.True(t, ok, "Expected claims to be present in context")
identity, ok := IdentityFromContext(r.Context())
require.True(t, ok, "Expected identity to be present in context")
require.NotNil(t, identity, "Expected identity to be non-nil")

// Verify the identity fields
assert.Equal(t, "anonymous", identity.Subject)
assert.Equal(t, "Anonymous User", identity.Name)
assert.Equal(t, "anonymous@localhost", identity.Email)

// Verify the anonymous claims
assert.Equal(t, "anonymous", claims["sub"])
assert.Equal(t, "toolhive-local", claims["iss"])
assert.Equal(t, "toolhive", claims["aud"])
assert.Equal(t, "anonymous@localhost", claims["email"])
assert.Equal(t, "Anonymous User", claims["name"])
require.NotNil(t, identity.Claims)
assert.Equal(t, "anonymous", identity.Claims["sub"])
assert.Equal(t, "toolhive-local", identity.Claims["iss"])
assert.Equal(t, "toolhive", identity.Claims["aud"])
assert.Equal(t, "anonymous@localhost", identity.Claims["email"])
assert.Equal(t, "Anonymous User", identity.Claims["name"])

// Verify timestamps are reasonable
now := time.Now().Unix()
exp, ok := claims["exp"].(int64)
exp, ok := identity.Claims["exp"].(int64)
require.True(t, ok, "Expected exp to be present and be an int64")
assert.Greater(t, exp, now, "Expected exp to be in the future")

iat, ok := claims["iat"].(int64)
iat, ok := identity.Claims["iat"].(int64)
require.True(t, ok, "Expected iat to be present and be an int64")
assert.LessOrEqual(t, iat, now+1, "Expected iat to be current time or earlier (with 1 second tolerance)")

Expand Down
18 changes: 0 additions & 18 deletions pkg/auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,6 @@ func IdentityFromContext(ctx context.Context) (*Identity, bool) {
return identity, ok
}

// GetClaimsFromContext retrieves the claims from Identity in the request context.
// This is a helper function for backward compatibility with code that expects MapClaims.
// New code should use IdentityFromContext and access the Claims field directly.
func GetClaimsFromContext(ctx context.Context) (jwt.MapClaims, bool) {
if ctx == nil {
return nil, false
}

// Get Identity and return its Claims
if identity, ok := IdentityFromContext(ctx); ok && identity != nil {
if identity.Claims != nil {
return jwt.MapClaims(identity.Claims), true
}
}

return nil, false
}

// claimsToIdentity converts JWT claims to Identity struct.
// It requires the 'sub' claim per OIDC Core 1.0 spec § 5.1.
// The original token can be provided for passthrough scenarios.
Expand Down
84 changes: 0 additions & 84 deletions pkg/auth/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"testing"

"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -92,89 +91,6 @@ func TestIdentityContext_ExplicitNilValue(t *testing.T) {
assert.Nil(t, identity, "expected nil identity")
}

// TestGetClaimsFromContext_EdgeCases verifies backward-compatible claims retrieval edge cases.
func TestGetClaimsFromContext_EdgeCases(t *testing.T) {
t.Parallel()

tests := []struct {
name string
setupCtx func() context.Context
wantOk bool
checkFunc func(t *testing.T, claims jwt.MapClaims)
}{
{
name: "identity_with_claims",
setupCtx: func() context.Context {
identity := &Identity{
Subject: "user123",
Claims: map[string]any{
"sub": "user123",
"org_id": "org456",
},
}
return WithIdentity(context.Background(), identity)
},
wantOk: true,
checkFunc: func(t *testing.T, claims jwt.MapClaims) {
t.Helper()
assert.Equal(t, "user123", claims["sub"])
assert.Equal(t, "org456", claims["org_id"])
},
},
{
name: "identity_with_nil_claims",
setupCtx: func() context.Context {
identity := &Identity{
Subject: "user123",
Claims: nil,
}
return WithIdentity(context.Background(), identity)
},
wantOk: false,
},
{
name: "no_identity",
setupCtx: func() context.Context {
return context.Background()
},
wantOk: false,
},
{
name: "nil_context",
setupCtx: func() context.Context {
return nil
},
wantOk: false,
},
{
name: "explicitly_nil_identity",
setupCtx: func() context.Context {
return context.WithValue(context.Background(), IdentityContextKey{}, (*Identity)(nil))
},
wantOk: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx := tt.setupCtx()
claims, ok := GetClaimsFromContext(ctx)

assert.Equal(t, tt.wantOk, ok)
if tt.wantOk {
require.NotNil(t, claims)
if tt.checkFunc != nil {
tt.checkFunc(t, claims)
}
} else {
assert.Nil(t, claims)
}
})
}
}

// TestIdentityContext_Overwrite verifies that storing a new identity replaces the old one.
func TestIdentityContext_Overwrite(t *testing.T) {
t.Parallel()
Expand Down
36 changes: 22 additions & 14 deletions pkg/auth/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,32 @@ func TestLocalUserMiddleware(t *testing.T) {
t.Parallel()
username := "testuser"

// Create a test handler that checks for claims in the context
// Create a test handler that checks for identity in the context
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := GetClaimsFromContext(r.Context())
require.True(t, ok, "Expected claims to be present in context")
identity, ok := IdentityFromContext(r.Context())
require.True(t, ok, "Expected identity to be present in context")
require.NotNil(t, identity, "Expected identity to be non-nil")

// Verify the identity fields
assert.Equal(t, username, identity.Subject)
assert.Equal(t, "Local User: "+username, identity.Name)
assert.Equal(t, username+"@localhost", identity.Email)

// Verify the local user claims
assert.Equal(t, username, claims["sub"])
assert.Equal(t, "toolhive-local", claims["iss"])
assert.Equal(t, "toolhive", claims["aud"])
assert.Equal(t, username+"@localhost", claims["email"])
assert.Equal(t, "Local User: "+username, claims["name"])
require.NotNil(t, identity.Claims)
assert.Equal(t, username, identity.Claims["sub"])
assert.Equal(t, "toolhive-local", identity.Claims["iss"])
assert.Equal(t, "toolhive", identity.Claims["aud"])
assert.Equal(t, username+"@localhost", identity.Claims["email"])
assert.Equal(t, "Local User: "+username, identity.Claims["name"])

// Verify timestamps are reasonable
now := time.Now().Unix()
exp, ok := claims["exp"].(int64)
exp, ok := identity.Claims["exp"].(int64)
require.True(t, ok, "Expected exp to be present and be an int64")
assert.Greater(t, exp, now, "Expected exp to be in the future")

iat, ok := claims["iat"].(int64)
iat, ok := identity.Claims["iat"].(int64)
require.True(t, ok, "Expected iat to be present and be an int64")
assert.LessOrEqual(t, iat, now+1, "Expected iat to be current time or earlier (with 1 second tolerance)")

Expand Down Expand Up @@ -63,11 +70,12 @@ func TestLocalUserMiddlewareWithDifferentUsernames(t *testing.T) {
t.Run("username_"+username, func(t *testing.T) {
t.Parallel()
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := GetClaimsFromContext(r.Context())
require.True(t, ok, "Expected claims to be present in context")
identity, ok := IdentityFromContext(r.Context())
require.True(t, ok, "Expected identity to be present in context")
require.NotNil(t, identity, "Expected identity to be non-nil")

assert.Equal(t, username, claims["sub"])
assert.Equal(t, username+"@localhost", claims["email"])
assert.Equal(t, username, identity.Subject)
assert.Equal(t, username+"@localhost", identity.Email)

w.WriteHeader(http.StatusOK)
})
Expand Down
12 changes: 6 additions & 6 deletions pkg/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,17 @@ func TestTokenValidatorMiddleware(t *testing.T) {

// Create a test handler
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the claims from the context using the helper function
claims, ok := GetClaimsFromContext(r.Context())
if !ok {
t.Errorf("Failed to get claims from context")
http.Error(w, "Failed to get claims from context", http.StatusInternalServerError)
// Get the identity from the context
identity, ok := IdentityFromContext(r.Context())
if !ok || identity == nil {
t.Errorf("Failed to get identity from context")
http.Error(w, "Failed to get identity from context", http.StatusInternalServerError)
return
}

// Write the claims as the response
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(claims); err != nil {
if err := json.NewEncoder(w).Encode(identity.Claims); err != nil {
t.Errorf("Failed to encode claims: %v", err)
http.Error(w, fmt.Sprintf("Failed to encode claims: %v", err), http.StatusInternalServerError)
return
Expand Down
100 changes: 5 additions & 95 deletions pkg/auth/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net/http/httptest"
"testing"

"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -105,97 +104,6 @@ func TestExtractBearerToken(t *testing.T) {
}
}

func TestGetClaimsFromContext(t *testing.T) {
t.Parallel()
// Test with claims in context
claims := jwt.MapClaims{
"sub": "testuser",
"iss": "test-issuer",
"aud": "test-audience",
}
identity := &Identity{Subject: "testuser", Claims: claims}
ctx := WithIdentity(context.Background(), identity)

retrievedClaims, ok := GetClaimsFromContext(ctx)
require.True(t, ok, "Expected to retrieve claims from context")
assert.Equal(t, "testuser", retrievedClaims["sub"])
assert.Equal(t, "test-issuer", retrievedClaims["iss"])

// Test with no identity in context
emptyCtx := context.Background()
_, ok = GetClaimsFromContext(emptyCtx)
assert.False(t, ok, "Expected no claims to be found in empty context")

// Test with identity that has nil claims
identityWithNilClaims := &Identity{Subject: "testuser", Claims: nil}
ctxWithNilClaims := WithIdentity(context.Background(), identityWithNilClaims)
_, ok = GetClaimsFromContext(ctxWithNilClaims)
assert.False(t, ok, "Expected no claims to be found when identity has nil claims")

// Test with nil context - we intentionally pass nil to test the nil check
//nolint:staticcheck // SA1012: Testing nil context handling is intentional
_, ok = GetClaimsFromContext(nil)
assert.False(t, ok, "Expected no claims to be found with nil context")
}

func TestGetClaimsFromContextWithDifferentClaimTypes(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
claims jwt.MapClaims
expected map[string]interface{}
}{
{
name: "string_claims",
claims: jwt.MapClaims{
"sub": "user123",
"email": "user@example.com",
"name": "Test User",
},
expected: map[string]interface{}{
"sub": "user123",
"email": "user@example.com",
"name": "Test User",
},
},
{
name: "mixed_claims",
claims: jwt.MapClaims{
"sub": "user123",
"exp": int64(1234567890),
"iat": int64(1234567800),
"admin": true,
},
expected: map[string]interface{}{
"sub": "user123",
"exp": int64(1234567890),
"iat": int64(1234567800),
"admin": true,
},
},
{
name: "empty_claims",
claims: jwt.MapClaims{},
expected: map[string]interface{}{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
identity := &Identity{Subject: "test-user", Claims: tc.claims}
ctx := WithIdentity(context.Background(), identity)
retrievedClaims, ok := GetClaimsFromContext(ctx)

require.True(t, ok, "Expected to retrieve claims from context")

for key, expectedValue := range tc.expected {
assert.Equal(t, expectedValue, retrievedClaims[key], "Expected %s to be %v, got %v", key, expectedValue, retrievedClaims[key])
}
})
}
}

func TestGetAuthenticationMiddleware(t *testing.T) {
t.Parallel()
// Initialize logger for testing
Expand All @@ -210,9 +118,11 @@ func TestGetAuthenticationMiddleware(t *testing.T) {

// Test that the middleware works by creating a test handler
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := GetClaimsFromContext(r.Context())
require.True(t, ok, "Expected claims to be present in context")
assert.Equal(t, "toolhive-local", claims["iss"])
identity, ok := IdentityFromContext(r.Context())
require.True(t, ok, "Expected identity to be present in context")
require.NotNil(t, identity, "Expected identity to be non-nil")
require.NotNil(t, identity.Claims, "Expected claims to be present")
assert.Equal(t, "toolhive-local", identity.Claims["iss"])
w.WriteHeader(http.StatusOK)
})

Expand Down
Loading