From ed3b079370a0b2e6ea8325a80aeba5193d79cd50 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Tue, 4 Nov 2025 08:17:31 +0200 Subject: [PATCH] Remove GetClaimsFromContext backward compatibility helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove GetClaimsFromContext function and migrate all usages to IdentityFromContext pattern following the authentication unification completed in d32f2bf5. Changes: - Remove GetClaimsFromContext() from pkg/auth/context.go - Update all test files to use IdentityFromContext() and access identity.Claims directly when needed - Remove dedicated GetClaimsFromContext test functions - Remove unused jwt imports from test files Rationale: GetClaimsFromContext was added as a backward-compatibility helper during the Identity struct unification. All production code has migrated to using IdentityFromContext directly, with zero production usages remaining. Tests should verify the actual production API contract. Impact: - Removes ~130 lines of code (function + tests) - Simplifies the auth API to a single pattern - All tests pass with improved clarity 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- pkg/auth/anonymous_test.go | 27 ++++++---- pkg/auth/context.go | 18 ------- pkg/auth/context_test.go | 84 ------------------------------- pkg/auth/local_test.go | 36 +++++++------ pkg/auth/token_test.go | 12 ++--- pkg/auth/utils_test.go | 100 ++----------------------------------- 6 files changed, 50 insertions(+), 227 deletions(-) diff --git a/pkg/auth/anonymous_test.go b/pkg/auth/anonymous_test.go index cc095d1b2..da4943998 100644 --- a/pkg/auth/anonymous_test.go +++ b/pkg/auth/anonymous_test.go @@ -12,25 +12,32 @@ import ( func TestAnonymousMiddleware(t *testing.T) { t.Parallel() - // Create a test handler that checks for claims in the context + // Create a test handler that checks for identity in the context testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := GetClaimsFromContext(r.Context()) - require.True(t, ok, "Expected claims to be present in context") + identity, ok := IdentityFromContext(r.Context()) + require.True(t, ok, "Expected identity to be present in context") + require.NotNil(t, identity, "Expected identity to be non-nil") + + // Verify the identity fields + assert.Equal(t, "anonymous", identity.Subject) + assert.Equal(t, "Anonymous User", identity.Name) + assert.Equal(t, "anonymous@localhost", identity.Email) // Verify the anonymous claims - assert.Equal(t, "anonymous", claims["sub"]) - assert.Equal(t, "toolhive-local", claims["iss"]) - assert.Equal(t, "toolhive", claims["aud"]) - assert.Equal(t, "anonymous@localhost", claims["email"]) - assert.Equal(t, "Anonymous User", claims["name"]) + require.NotNil(t, identity.Claims) + assert.Equal(t, "anonymous", identity.Claims["sub"]) + assert.Equal(t, "toolhive-local", identity.Claims["iss"]) + assert.Equal(t, "toolhive", identity.Claims["aud"]) + assert.Equal(t, "anonymous@localhost", identity.Claims["email"]) + assert.Equal(t, "Anonymous User", identity.Claims["name"]) // Verify timestamps are reasonable now := time.Now().Unix() - exp, ok := claims["exp"].(int64) + exp, ok := identity.Claims["exp"].(int64) require.True(t, ok, "Expected exp to be present and be an int64") assert.Greater(t, exp, now, "Expected exp to be in the future") - iat, ok := claims["iat"].(int64) + iat, ok := identity.Claims["iat"].(int64) require.True(t, ok, "Expected iat to be present and be an int64") assert.LessOrEqual(t, iat, now+1, "Expected iat to be current time or earlier (with 1 second tolerance)") diff --git a/pkg/auth/context.go b/pkg/auth/context.go index 0c952d090..50ccb5f33 100644 --- a/pkg/auth/context.go +++ b/pkg/auth/context.go @@ -50,24 +50,6 @@ func IdentityFromContext(ctx context.Context) (*Identity, bool) { return identity, ok } -// GetClaimsFromContext retrieves the claims from Identity in the request context. -// This is a helper function for backward compatibility with code that expects MapClaims. -// New code should use IdentityFromContext and access the Claims field directly. -func GetClaimsFromContext(ctx context.Context) (jwt.MapClaims, bool) { - if ctx == nil { - return nil, false - } - - // Get Identity and return its Claims - if identity, ok := IdentityFromContext(ctx); ok && identity != nil { - if identity.Claims != nil { - return jwt.MapClaims(identity.Claims), true - } - } - - return nil, false -} - // claimsToIdentity converts JWT claims to Identity struct. // It requires the 'sub' claim per OIDC Core 1.0 spec § 5.1. // The original token can be provided for passthrough scenarios. diff --git a/pkg/auth/context_test.go b/pkg/auth/context_test.go index 2a6cf27cc..eef0b88b9 100644 --- a/pkg/auth/context_test.go +++ b/pkg/auth/context_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -92,89 +91,6 @@ func TestIdentityContext_ExplicitNilValue(t *testing.T) { assert.Nil(t, identity, "expected nil identity") } -// TestGetClaimsFromContext_EdgeCases verifies backward-compatible claims retrieval edge cases. -func TestGetClaimsFromContext_EdgeCases(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - setupCtx func() context.Context - wantOk bool - checkFunc func(t *testing.T, claims jwt.MapClaims) - }{ - { - name: "identity_with_claims", - setupCtx: func() context.Context { - identity := &Identity{ - Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - "org_id": "org456", - }, - } - return WithIdentity(context.Background(), identity) - }, - wantOk: true, - checkFunc: func(t *testing.T, claims jwt.MapClaims) { - t.Helper() - assert.Equal(t, "user123", claims["sub"]) - assert.Equal(t, "org456", claims["org_id"]) - }, - }, - { - name: "identity_with_nil_claims", - setupCtx: func() context.Context { - identity := &Identity{ - Subject: "user123", - Claims: nil, - } - return WithIdentity(context.Background(), identity) - }, - wantOk: false, - }, - { - name: "no_identity", - setupCtx: func() context.Context { - return context.Background() - }, - wantOk: false, - }, - { - name: "nil_context", - setupCtx: func() context.Context { - return nil - }, - wantOk: false, - }, - { - name: "explicitly_nil_identity", - setupCtx: func() context.Context { - return context.WithValue(context.Background(), IdentityContextKey{}, (*Identity)(nil)) - }, - wantOk: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := tt.setupCtx() - claims, ok := GetClaimsFromContext(ctx) - - assert.Equal(t, tt.wantOk, ok) - if tt.wantOk { - require.NotNil(t, claims) - if tt.checkFunc != nil { - tt.checkFunc(t, claims) - } - } else { - assert.Nil(t, claims) - } - }) - } -} - // TestIdentityContext_Overwrite verifies that storing a new identity replaces the old one. func TestIdentityContext_Overwrite(t *testing.T) { t.Parallel() diff --git a/pkg/auth/local_test.go b/pkg/auth/local_test.go index de43e58a4..908ca88d6 100644 --- a/pkg/auth/local_test.go +++ b/pkg/auth/local_test.go @@ -14,25 +14,32 @@ func TestLocalUserMiddleware(t *testing.T) { t.Parallel() username := "testuser" - // Create a test handler that checks for claims in the context + // Create a test handler that checks for identity in the context testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := GetClaimsFromContext(r.Context()) - require.True(t, ok, "Expected claims to be present in context") + identity, ok := IdentityFromContext(r.Context()) + require.True(t, ok, "Expected identity to be present in context") + require.NotNil(t, identity, "Expected identity to be non-nil") + + // Verify the identity fields + assert.Equal(t, username, identity.Subject) + assert.Equal(t, "Local User: "+username, identity.Name) + assert.Equal(t, username+"@localhost", identity.Email) // Verify the local user claims - assert.Equal(t, username, claims["sub"]) - assert.Equal(t, "toolhive-local", claims["iss"]) - assert.Equal(t, "toolhive", claims["aud"]) - assert.Equal(t, username+"@localhost", claims["email"]) - assert.Equal(t, "Local User: "+username, claims["name"]) + require.NotNil(t, identity.Claims) + assert.Equal(t, username, identity.Claims["sub"]) + assert.Equal(t, "toolhive-local", identity.Claims["iss"]) + assert.Equal(t, "toolhive", identity.Claims["aud"]) + assert.Equal(t, username+"@localhost", identity.Claims["email"]) + assert.Equal(t, "Local User: "+username, identity.Claims["name"]) // Verify timestamps are reasonable now := time.Now().Unix() - exp, ok := claims["exp"].(int64) + exp, ok := identity.Claims["exp"].(int64) require.True(t, ok, "Expected exp to be present and be an int64") assert.Greater(t, exp, now, "Expected exp to be in the future") - iat, ok := claims["iat"].(int64) + iat, ok := identity.Claims["iat"].(int64) require.True(t, ok, "Expected iat to be present and be an int64") assert.LessOrEqual(t, iat, now+1, "Expected iat to be current time or earlier (with 1 second tolerance)") @@ -63,11 +70,12 @@ func TestLocalUserMiddlewareWithDifferentUsernames(t *testing.T) { t.Run("username_"+username, func(t *testing.T) { t.Parallel() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := GetClaimsFromContext(r.Context()) - require.True(t, ok, "Expected claims to be present in context") + identity, ok := IdentityFromContext(r.Context()) + require.True(t, ok, "Expected identity to be present in context") + require.NotNil(t, identity, "Expected identity to be non-nil") - assert.Equal(t, username, claims["sub"]) - assert.Equal(t, username+"@localhost", claims["email"]) + assert.Equal(t, username, identity.Subject) + assert.Equal(t, username+"@localhost", identity.Email) w.WriteHeader(http.StatusOK) }) diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go index d40780398..acb9944e2 100644 --- a/pkg/auth/token_test.go +++ b/pkg/auth/token_test.go @@ -242,17 +242,17 @@ func TestTokenValidatorMiddleware(t *testing.T) { // Create a test handler testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Get the claims from the context using the helper function - claims, ok := GetClaimsFromContext(r.Context()) - if !ok { - t.Errorf("Failed to get claims from context") - http.Error(w, "Failed to get claims from context", http.StatusInternalServerError) + // Get the identity from the context + identity, ok := IdentityFromContext(r.Context()) + if !ok || identity == nil { + t.Errorf("Failed to get identity from context") + http.Error(w, "Failed to get identity from context", http.StatusInternalServerError) return } // Write the claims as the response w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(claims); err != nil { + if err := json.NewEncoder(w).Encode(identity.Claims); err != nil { t.Errorf("Failed to encode claims: %v", err) http.Error(w, fmt.Sprintf("Failed to encode claims: %v", err), http.StatusInternalServerError) return diff --git a/pkg/auth/utils_test.go b/pkg/auth/utils_test.go index da58e728e..9ee518f58 100644 --- a/pkg/auth/utils_test.go +++ b/pkg/auth/utils_test.go @@ -7,7 +7,6 @@ import ( "net/http/httptest" "testing" - "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -105,97 +104,6 @@ func TestExtractBearerToken(t *testing.T) { } } -func TestGetClaimsFromContext(t *testing.T) { - t.Parallel() - // Test with claims in context - claims := jwt.MapClaims{ - "sub": "testuser", - "iss": "test-issuer", - "aud": "test-audience", - } - identity := &Identity{Subject: "testuser", Claims: claims} - ctx := WithIdentity(context.Background(), identity) - - retrievedClaims, ok := GetClaimsFromContext(ctx) - require.True(t, ok, "Expected to retrieve claims from context") - assert.Equal(t, "testuser", retrievedClaims["sub"]) - assert.Equal(t, "test-issuer", retrievedClaims["iss"]) - - // Test with no identity in context - emptyCtx := context.Background() - _, ok = GetClaimsFromContext(emptyCtx) - assert.False(t, ok, "Expected no claims to be found in empty context") - - // Test with identity that has nil claims - identityWithNilClaims := &Identity{Subject: "testuser", Claims: nil} - ctxWithNilClaims := WithIdentity(context.Background(), identityWithNilClaims) - _, ok = GetClaimsFromContext(ctxWithNilClaims) - assert.False(t, ok, "Expected no claims to be found when identity has nil claims") - - // Test with nil context - we intentionally pass nil to test the nil check - //nolint:staticcheck // SA1012: Testing nil context handling is intentional - _, ok = GetClaimsFromContext(nil) - assert.False(t, ok, "Expected no claims to be found with nil context") -} - -func TestGetClaimsFromContextWithDifferentClaimTypes(t *testing.T) { - t.Parallel() - testCases := []struct { - name string - claims jwt.MapClaims - expected map[string]interface{} - }{ - { - name: "string_claims", - claims: jwt.MapClaims{ - "sub": "user123", - "email": "user@example.com", - "name": "Test User", - }, - expected: map[string]interface{}{ - "sub": "user123", - "email": "user@example.com", - "name": "Test User", - }, - }, - { - name: "mixed_claims", - claims: jwt.MapClaims{ - "sub": "user123", - "exp": int64(1234567890), - "iat": int64(1234567800), - "admin": true, - }, - expected: map[string]interface{}{ - "sub": "user123", - "exp": int64(1234567890), - "iat": int64(1234567800), - "admin": true, - }, - }, - { - name: "empty_claims", - claims: jwt.MapClaims{}, - expected: map[string]interface{}{}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - identity := &Identity{Subject: "test-user", Claims: tc.claims} - ctx := WithIdentity(context.Background(), identity) - retrievedClaims, ok := GetClaimsFromContext(ctx) - - require.True(t, ok, "Expected to retrieve claims from context") - - for key, expectedValue := range tc.expected { - assert.Equal(t, expectedValue, retrievedClaims[key], "Expected %s to be %v, got %v", key, expectedValue, retrievedClaims[key]) - } - }) - } -} - func TestGetAuthenticationMiddleware(t *testing.T) { t.Parallel() // Initialize logger for testing @@ -210,9 +118,11 @@ func TestGetAuthenticationMiddleware(t *testing.T) { // Test that the middleware works by creating a test handler testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := GetClaimsFromContext(r.Context()) - require.True(t, ok, "Expected claims to be present in context") - assert.Equal(t, "toolhive-local", claims["iss"]) + identity, ok := IdentityFromContext(r.Context()) + require.True(t, ok, "Expected identity to be present in context") + require.NotNil(t, identity, "Expected identity to be non-nil") + require.NotNil(t, identity.Claims, "Expected claims to be present") + assert.Equal(t, "toolhive-local", identity.Claims["iss"]) w.WriteHeader(http.StatusOK) })