From 52695da3dcd0dfe572971b391882c1221768c3bf Mon Sep 17 00:00:00 2001 From: Prince Roshan Date: Thu, 30 Apr 2026 17:04:50 +0530 Subject: [PATCH 1/2] fix(access): exchange oidc tokens for platform jwt sessions --- services/api/main.go | 1 + services/api/platform_auth.go | 113 ++++++++++++++++ services/api/platform_auth_oidc_test.go | 167 ++++++++++++++++++++++++ services/api/platform_store.go | 2 +- services/ui/main.go | 88 ++++++------- services/ui/main_test.go | 84 +++++++++--- 6 files changed, 392 insertions(+), 63 deletions(-) create mode 100644 services/api/platform_auth_oidc_test.go diff --git a/services/api/main.go b/services/api/main.go index 29843b1..73d619a 100644 --- a/services/api/main.go +++ b/services/api/main.go @@ -219,6 +219,7 @@ func main() { }) }) mux.HandleFunc("/api/auth/login", server.handleLogin) + mux.HandleFunc("/api/auth/oidc", server.handleOIDCLogin) mux.HandleFunc("/api/auth/signup", server.handleSignup) mux.Handle("/api/events", server.auth(server.requireRole(roleAdmin, http.HandlerFunc(server.handleEvents)))) mux.Handle("/api/stats", server.auth(server.requireRole(roleAdmin, http.HandlerFunc(server.handleStats)))) diff --git a/services/api/platform_auth.go b/services/api/platform_auth.go index 8ae7913..caf9038 100644 --- a/services/api/platform_auth.go +++ b/services/api/platform_auth.go @@ -21,6 +21,8 @@ const ( ) var platformLoginAttempts = newAPILoginAttemptTracker(time.Now) +var oidcLoginHook func(context.Context, *apiServer, string) (platformUser, error) +var errOIDCUnauthorized = errors.New("oidc unauthorized") type apiLoginAttempt struct { failures int @@ -281,6 +283,117 @@ func (s *apiServer) handleLogin(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]any{"access_token": token, "token_type": "bearer", "expires_in": int(platformAccessTokenTTL.Seconds()), "user": u}) } +func (s *apiServer) handleOIDCLogin(w http.ResponseWriter, r *http.Request) { + if s.platform == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "platform identity database not configured"}) + return + } + if r.Method != http.MethodPost { + w.Header().Set("allow", "POST") + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method_not_allowed"}) + return + } + if s.jwks == nil || strings.TrimSpace(s.oidcIssuer) == "" || strings.TrimSpace(s.oidcAudience) == "" { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc_not_configured"}) + return + } + + var req struct { + IDToken string `json:"id_token"` + } + r.Body = http.MaxBytesReader(w, r.Body, 8192) + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeBodyDecodeError(w, err) + return + } + idToken := strings.TrimSpace(req.IDToken) + if idToken == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "missing_id_token"}) + return + } + + var ( + u platformUser + err error + ) + if oidcLoginHook != nil { + u, err = oidcLoginHook(r.Context(), s, idToken) + } else { + u, err = s.resolveOIDCLoginUser(r.Context(), idToken) + } + if err != nil { + statusCode := http.StatusInternalServerError + auditStatus := "error" + auditResource := strings.ToLower(strings.TrimSpace(u.Email)) + if errors.Is(err, errOIDCUnauthorized) { + statusCode = http.StatusUnauthorized + auditStatus = "denied" + } + s.platform.WriteAudit(r.Context(), auditEvent{ + Action: "oidc_login", + Resource: auditResource, + Status: auditStatus, + Message: err.Error(), + ActorIP: requestIP(r), + }) + if statusCode == http.StatusUnauthorized { + writeJSON(w, statusCode, map[string]string{"error": "unauthorized"}) + return + } + writeJSON(w, statusCode, map[string]string{"error": "login_failed"}) + return + } + + token, err := s.platform.CreateAccessToken(u, platformAccessTokenTTL) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to issue token"}) + return + } + s.platform.WriteAudit(r.Context(), auditEvent{UserID: u.ID, Action: "oidc_login", Resource: "user", Namespace: u.Namespace, Status: "success", ActorIP: requestIP(r)}) + writeJSON(w, http.StatusOK, map[string]any{"access_token": token, "token_type": "bearer", "expires_in": int(platformAccessTokenTTL.Seconds()), "user": u}) +} + +func (s *apiServer) resolveOIDCLoginUser(ctx context.Context, idToken string) (platformUser, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost/api/auth/me", nil) + if err != nil { + return platformUser{}, err + } + req.Header.Set("authorization", "Bearer "+idToken) + + p, ok, err := s.authenticateRequest(req) + if err != nil { + return platformUser{}, err + } + if !ok || strings.TrimSpace(p.AuthType) != "oidc_jwt" { + return platformUser{}, fmt.Errorf("%w: token authentication failed", errOIDCUnauthorized) + } + if strings.TrimSpace(p.Subject) == "" { + return platformUser{}, fmt.Errorf("%w: token missing user identity", errOIDCUnauthorized) + } + if strings.TrimSpace(p.Email) == "" { + return platformUser{}, fmt.Errorf("%w: token missing email", errOIDCUnauthorized) + } + role := strings.TrimSpace(p.Role) + if role == "" { + role = roleUser + } + + u, ok, err := s.platform.GetUser(ctx, p.Subject) + if err != nil { + return platformUser{}, err + } + if ok { + return u, nil + } + + return platformUser{ + ID: strings.TrimSpace(p.Subject), + Email: strings.TrimSpace(p.Email), + Role: role, + Namespace: strings.TrimSpace(p.Namespace), + }, nil +} + func requestIP(r *http.Request) string { if xff := strings.TrimSpace(r.Header.Get("x-forwarded-for")); xff != "" { return strings.TrimSpace(strings.Split(xff, ",")[0]) diff --git a/services/api/platform_auth_oidc_test.go b/services/api/platform_auth_oidc_test.go new file mode 100644 index 0000000..09148b6 --- /dev/null +++ b/services/api/platform_auth_oidc_test.go @@ -0,0 +1,167 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/MicahParks/keyfunc" + "github.com/golang-jwt/jwt/v4" +) + +func TestHandleOIDCLoginSuccess(t *testing.T) { + previousHook := oidcLoginHook + oidcLoginHook = func(_ context.Context, _ *apiServer, token string) (platformUser, error) { + if token != "google-id-token" { + t.Fatalf("id token = %q", token) + } + return platformUser{ + ID: "user-123", + Email: "user@example.com", + Role: roleUser, + Namespace: "user-1", + }, nil + } + defer func() { oidcLoginHook = previousHook }() + + server := &apiServer{ + platform: &platformStore{jwtSecret: []byte("test-secret")}, + jwks: &keyfunc.JWKS{}, + oidcIssuer: "https://issuer.example", + oidcAudience: "client-id", + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/oidc", strings.NewReader(`{"id_token":"google-id-token"}`)) + server.handleOIDCLogin(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var payload struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + User platformUser `json:"user"` + } + if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload.AccessToken == "" { + t.Fatal("expected access token") + } + if strings.Contains(payload.AccessToken, "google-id-token") { + t.Fatalf("platform token leaked raw id token: %q", payload.AccessToken) + } + if payload.TokenType != "bearer" { + t.Fatalf("token_type = %q, want bearer", payload.TokenType) + } + if payload.ExpiresIn != int(platformAccessTokenTTL.Seconds()) { + t.Fatalf("expires_in = %d, want %d", payload.ExpiresIn, int(platformAccessTokenTTL.Seconds())) + } + if payload.User.ID != "user-123" || payload.User.Email != "user@example.com" { + t.Fatalf("user payload = %+v", payload.User) + } + + parsed, err := jwt.Parse(payload.AccessToken, func(t *jwt.Token) (any, error) { + return []byte("test-secret"), nil + }) + if err != nil || !parsed.Valid { + t.Fatalf("platform token parse failed: %v", err) + } + claims, ok := parsed.Claims.(jwt.MapClaims) + if !ok { + t.Fatal("missing jwt claims") + } + if got := strings.TrimSpace(fmt.Sprint(claims["sub"])); got != "user-123" { + t.Fatalf("subject claim = %q, want user-123", got) + } +} + +func TestHandleOIDCLoginRequiresPlatformStore(t *testing.T) { + server := &apiServer{ + jwks: &keyfunc.JWKS{}, + oidcIssuer: "https://issuer.example", + oidcAudience: "client-id", + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/oidc", strings.NewReader(`{"id_token":"google-id-token"}`)) + server.handleOIDCLogin(rec, req) + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable) + } +} + +func TestHandleOIDCLoginRequiresOIDCConfig(t *testing.T) { + server := &apiServer{ + platform: &platformStore{jwtSecret: []byte("test-secret")}, + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/oidc", strings.NewReader(`{"id_token":"google-id-token"}`)) + server.handleOIDCLogin(rec, req) + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable) + } +} + +func TestHandleOIDCLoginMissingToken(t *testing.T) { + server := &apiServer{ + platform: &platformStore{jwtSecret: []byte("test-secret")}, + jwks: &keyfunc.JWKS{}, + oidcIssuer: "https://issuer.example", + oidcAudience: "client-id", + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/oidc", strings.NewReader(`{}`)) + server.handleOIDCLogin(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestHandleOIDCLoginInternalError(t *testing.T) { + previousHook := oidcLoginHook + oidcLoginHook = func(_ context.Context, _ *apiServer, _ string) (platformUser, error) { + return platformUser{Email: "user@example.com"}, errors.New("failed") + } + defer func() { oidcLoginHook = previousHook }() + + server := &apiServer{ + platform: &platformStore{jwtSecret: []byte("test-secret")}, + jwks: &keyfunc.JWKS{}, + oidcIssuer: "https://issuer.example", + oidcAudience: "client-id", + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/oidc", strings.NewReader(`{"id_token":"google-id-token"}`)) + server.handleOIDCLogin(rec, req) + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } +} + +func TestHandleOIDCLoginInvalidOIDCToken(t *testing.T) { + previousHook := oidcLoginHook + oidcLoginHook = func(_ context.Context, _ *apiServer, _ string) (platformUser, error) { + return platformUser{Email: "user@example.com"}, errOIDCUnauthorized + } + defer func() { oidcLoginHook = previousHook }() + + server := &apiServer{ + platform: &platformStore{jwtSecret: []byte("test-secret")}, + jwks: &keyfunc.JWKS{}, + oidcIssuer: "https://issuer.example", + oidcAudience: "client-id", + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/oidc", strings.NewReader(`{"id_token":"google-id-token"}`)) + server.handleOIDCLogin(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} diff --git a/services/api/platform_store.go b/services/api/platform_store.go index bfde655..404d27c 100644 --- a/services/api/platform_store.go +++ b/services/api/platform_store.go @@ -597,7 +597,7 @@ func (s *platformStore) ListNamespaces(ctx context.Context) ([]map[string]any, e } func (s *platformStore) WriteAudit(ctx context.Context, ev auditEvent) { - if s == nil { + if s == nil || s.db == nil { return } _, _ = s.db.ExecContext(ctx, `INSERT INTO audit_logs (user_id,action,resource,namespace,status,message,actor_ip,request_id) VALUES (NULLIF($1,'')::uuid,$2,$3,$4,$5,$6,$7,$8)`, diff --git a/services/ui/main.go b/services/ui/main.go index cab3c9b..a8a4643 100644 --- a/services/ui/main.go +++ b/services/ui/main.go @@ -70,6 +70,7 @@ type uiSession struct { UpstreamAPIKey string } +// uiSessionStore is intentionally in-memory only; sessions are cleared on UI restart. type uiSessionStore struct { mu sync.Mutex sessions map[string]uiSession @@ -93,7 +94,7 @@ type loginClientState struct { var ( loginAttempts = newLoginAttemptTracker(time.Now) sessions = newUISessionStore(time.Now) - oidcVerifyHook func(context.Context, string, string) (sessionPrincipal, error) + oidcLoginHook func(context.Context, string, string) (sessionPrincipal, string, time.Time, error) authHTTPClient = &http.Client{Timeout: 10 * time.Second} ) @@ -364,12 +365,14 @@ func handleLogin(apiKey, upstreamAPIKey, apiUpstream string, store *uiSessionSto } else if idToken != "" { var ( p sessionPrincipal + token string + expiresAt time.Time verifyErr error ) - if oidcVerifyHook != nil { - p, verifyErr = oidcVerifyHook(r.Context(), apiUpstream, idToken) + if oidcLoginHook != nil { + p, token, expiresAt, verifyErr = oidcLoginHook(r.Context(), apiUpstream, idToken) } else { - p, verifyErr = verifyOIDCTokenWithAPI(r.Context(), apiUpstream, idToken) + p, token, expiresAt, verifyErr = loginOIDCWithAPI(r.Context(), apiUpstream, idToken) } if verifyErr != nil { failures := loginAttempts.recordFailure(clientID) @@ -382,8 +385,8 @@ func handleLogin(apiKey, upstreamAPIKey, apiUpstream string, store *uiSessionSto } sess, err = store.createSession(r.Context(), uiSession{ Principal: p, - UpstreamAuthHeader: "Bearer " + idToken, - ExpiresAt: idTokenExpiry(idToken), + UpstreamAuthHeader: "Bearer " + token, + ExpiresAt: expiresAt, }) } else { if apiKey == "" { @@ -477,69 +480,64 @@ func loginPasswordWithAPI(ctx context.Context, apiUpstream, email, password stri }, payload.AccessToken, nil } -func idTokenExpiry(idToken string) time.Time { - parts := strings.Split(strings.TrimSpace(idToken), ".") - if len(parts) < 2 { - return time.Time{} - } - payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) +func loginOIDCWithAPI(ctx context.Context, apiUpstream, idToken string) (sessionPrincipal, string, time.Time, error) { + oidcURL, err := apiUpstreamURL(apiUpstream, "api", "auth", "oidc") if err != nil { - return time.Time{} - } - var claims struct { - Exp int64 `json:"exp"` - } - if err := json.Unmarshal(payloadBytes, &claims); err != nil { - return time.Time{} - } - if claims.Exp <= 0 { - return time.Time{} - } - return time.Unix(claims.Exp, 0).UTC() -} - -func verifyOIDCTokenWithAPI(ctx context.Context, apiUpstream, idToken string) (sessionPrincipal, error) { - meURL, err := apiUpstreamURL(apiUpstream, "api", "auth", "me") - if err != nil { - return sessionPrincipal{}, err + return sessionPrincipal{}, "", time.Time{}, err } ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, meURL, nil) + body, err := json.Marshal(map[string]string{"id_token": idToken}) if err != nil { - return sessionPrincipal{}, err + return sessionPrincipal{}, "", time.Time{}, err } - req.Header.Set("authorization", "Bearer "+idToken) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, oidcURL, strings.NewReader(string(body))) + if err != nil { + return sessionPrincipal{}, "", time.Time{}, err + } + req.Header.Set("content-type", "application/json") resp, err := authHTTPClient.Do(req) if err != nil { - return sessionPrincipal{}, err + return sessionPrincipal{}, "", time.Time{}, err } defer drainAndClose(resp.Body) if resp.StatusCode != http.StatusOK { _, _ = io.Copy(io.Discard, resp.Body) - return sessionPrincipal{}, fmt.Errorf("auth check failed: status %d", resp.StatusCode) + return sessionPrincipal{}, "", time.Time{}, fmt.Errorf("oidc login failed: status %d", resp.StatusCode) } var payload struct { - Authenticated bool `json:"authenticated"` - Principal sessionPrincipal `json:"principal"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + User struct { + ID string `json:"id"` + Email string `json:"email"` + Role string `json:"role"` + } `json:"user"` } if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { - return sessionPrincipal{}, err + return sessionPrincipal{}, "", time.Time{}, err } - if !payload.Authenticated { - return sessionPrincipal{}, errors.New("not authenticated") + if strings.TrimSpace(payload.AccessToken) == "" { + return sessionPrincipal{}, "", time.Time{}, errors.New("missing access token") } - if strings.TrimSpace(payload.Principal.Role) == "" { - payload.Principal.Role = "user" + role := strings.TrimSpace(payload.User.Role) + if role == "" { + role = "user" } - if payload.Principal.AuthType == "" { - payload.Principal.AuthType = "oidc_jwt" + var expiresAt time.Time + if payload.ExpiresIn > 0 { + expiresAt = time.Now().UTC().Add(time.Duration(payload.ExpiresIn) * time.Second) } - return payload.Principal, nil + return sessionPrincipal{ + Role: role, + Subject: strings.TrimSpace(payload.User.ID), + Email: strings.TrimSpace(payload.User.Email), + AuthType: "platform_jwt", + }, strings.TrimSpace(payload.AccessToken), expiresAt, nil } func apiUpstreamURL(apiUpstream string, parts ...string) (string, error) { diff --git a/services/ui/main_test.go b/services/ui/main_test.go index 458f523..c2c8875 100644 --- a/services/ui/main_test.go +++ b/services/ui/main_test.go @@ -2,8 +2,6 @@ package main import ( "context" - "encoding/base64" - "fmt" "io" "net/http" "net/http/httptest" @@ -161,8 +159,9 @@ func TestAPIProxyAllowsPublicRuntimeServers(t *testing.T) { } func TestHandleLoginWithOIDCToken(t *testing.T) { - previousHook := oidcVerifyHook - oidcVerifyHook = func(_ context.Context, upstream, token string) (sessionPrincipal, error) { + now := time.Now().UTC() + previousHook := oidcLoginHook + oidcLoginHook = func(_ context.Context, upstream, token string) (sessionPrincipal, string, time.Time, error) { if upstream != "http://api.example" { t.Fatalf("upstream = %q, want http://api.example", upstream) } @@ -172,10 +171,10 @@ func TestHandleLoginWithOIDCToken(t *testing.T) { return sessionPrincipal{ Role: "user", Subject: "user-123", - AuthType: "oidc_jwt", - }, nil + AuthType: "platform_jwt", + }, "platform-token", now.Add(15 * time.Minute), nil } - defer func() { oidcVerifyHook = previousHook }() + defer func() { oidcLoginHook = previousHook }() store := newUISessionStore(time.Now) login := httptest.NewRecorder() @@ -187,11 +186,21 @@ func TestHandleLoginWithOIDCToken(t *testing.T) { if len(cookies) != 1 { t.Fatalf("cookies = %d, want 1", len(cookies)) } + sess, ok := store.get(cookies[0].Value) + if !ok { + t.Fatal("expected persisted session") + } + if got := sess.UpstreamAuthHeader; got != "Bearer platform-token" { + t.Fatalf("stored upstream authorization = %q", got) + } + if got := sess.UpstreamAuthHeader; strings.Contains(got, "id-token") { + t.Fatalf("stored upstream authorization leaked raw id token: %q", got) + } upstreamCalled := false transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { upstreamCalled = true - if got := r.Header.Get("authorization"); got != "Bearer id-token" { + if got := r.Header.Get("authorization"); got != "Bearer platform-token" { t.Fatalf("authorization forwarded = %q", got) } if got := r.Header.Get("x-api-key"); got != "" { @@ -216,8 +225,8 @@ func TestHandleLoginWithOIDCToken(t *testing.T) { func TestHandleLoginWithOIDCTokenCapsSessionToTokenExpiry(t *testing.T) { now := time.Now().UTC().Add(2 * time.Minute) - previousHook := oidcVerifyHook - oidcVerifyHook = func(_ context.Context, upstream, token string) (sessionPrincipal, error) { + previousHook := oidcLoginHook + oidcLoginHook = func(_ context.Context, upstream, token string) (sessionPrincipal, string, time.Time, error) { if upstream != "http://api.example" { t.Fatalf("upstream = %q, want http://api.example", upstream) } @@ -227,17 +236,14 @@ func TestHandleLoginWithOIDCTokenCapsSessionToTokenExpiry(t *testing.T) { return sessionPrincipal{ Role: "user", Subject: "user-123", - AuthType: "oidc_jwt", - }, nil + AuthType: "platform_jwt", + }, "platform-token", now.Add(30 * time.Minute), nil } - defer func() { oidcVerifyHook = previousHook }() + defer func() { oidcLoginHook = previousHook }() - exp := now.Add(30 * time.Minute) - payload := fmt.Sprintf(`{"exp":%d}`, exp.Unix()) - idToken := "eyJhbGciOiJub25lIn0." + base64.RawURLEncoding.EncodeToString([]byte(payload)) + ".sig" store := newUISessionStore(func() time.Time { return now }) login := httptest.NewRecorder() - handleLogin("", "api-secret", "http://api.example", store).ServeHTTP(login, httptest.NewRequest(http.MethodPost, "/auth/login", strings.NewReader(`{"id_token":"`+idToken+`"}`))) + handleLogin("", "api-secret", "http://api.example", store).ServeHTTP(login, httptest.NewRequest(http.MethodPost, "/auth/login", strings.NewReader(`{"id_token":"id-token"}`))) if login.Code != http.StatusOK { t.Fatalf("login status = %d, want %d; body=%s", login.Code, http.StatusOK, login.Body.String()) } @@ -252,11 +258,55 @@ func TestHandleLoginWithOIDCTokenCapsSessionToTokenExpiry(t *testing.T) { if !ok { t.Fatal("expected persisted session") } + exp := now.Add(30 * time.Minute) if sess.ExpiresAt.After(exp.Add(time.Second)) || sess.ExpiresAt.Before(exp.Add(-1*time.Second)) { t.Fatalf("session expiry = %s, want %s", sess.ExpiresAt.Format(time.RFC3339), exp.Format(time.RFC3339)) } } +func TestUISessionStateIsEphemeralAcrossStoreRestart(t *testing.T) { + now := time.Now().UTC() + previousHook := oidcLoginHook + oidcLoginHook = func(_ context.Context, upstream, token string) (sessionPrincipal, string, time.Time, error) { + if upstream != "http://api.example" { + t.Fatalf("upstream = %q, want http://api.example", upstream) + } + if token == "" { + t.Fatal("token should not be empty") + } + return sessionPrincipal{Role: "user", Subject: "user-123", AuthType: "platform_jwt"}, "platform-token", now.Add(10 * time.Minute), nil + } + defer func() { oidcLoginHook = previousHook }() + + originalStore := newUISessionStore(func() time.Time { return now }) + login := httptest.NewRecorder() + handleLogin("", "api-secret", "http://api.example", originalStore).ServeHTTP(login, httptest.NewRequest(http.MethodPost, "/auth/login", strings.NewReader(`{"id_token":"id-token"}`))) + if login.Code != http.StatusOK { + t.Fatalf("login status = %d, want %d; body=%s", login.Code, http.StatusOK, login.Body.String()) + } + cookies := login.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("cookies = %d, want 1", len(cookies)) + } + + beforeRestart := httptest.NewRecorder() + beforeReq := httptest.NewRequest(http.MethodGet, "/auth/status", nil) + beforeReq.AddCookie(cookies[0]) + handleStatus(originalStore).ServeHTTP(beforeRestart, beforeReq) + if !strings.Contains(beforeRestart.Body.String(), `"authenticated":true`) { + t.Fatalf("status before restart = %s", beforeRestart.Body.String()) + } + + restartedStore := newUISessionStore(func() time.Time { return now }) + afterRestart := httptest.NewRecorder() + afterReq := httptest.NewRequest(http.MethodGet, "/auth/status", nil) + afterReq.AddCookie(cookies[0]) + handleStatus(restartedStore).ServeHTTP(afterRestart, afterReq) + if !strings.Contains(afterRestart.Body.String(), `"authenticated":false`) { + t.Fatalf("status after restart = %s", afterRestart.Body.String()) + } +} + func TestHandleLoginWithPassword(t *testing.T) { previousHook := passwordLoginHook passwordLoginHook = func(_ context.Context, upstream, email, password string) (sessionPrincipal, string, error) { From 071a43b18cb8b35a0e0643c7262be08b5287558d Mon Sep 17 00:00:00 2001 From: Prince Roshan Date: Thu, 30 Apr 2026 17:13:53 +0000 Subject: [PATCH 2/2] fix(access): restore oidc login fallback --- services/api/platform_auth.go | 49 +++++++------ services/api/platform_auth_oidc_test.go | 26 +++++++ services/ui/main.go | 93 ++++++++++++++++++++++++- services/ui/main_test.go | 63 +++++++++++++++++ 4 files changed, 206 insertions(+), 25 deletions(-) diff --git a/services/api/platform_auth.go b/services/api/platform_auth.go index caf9038..28c72f2 100644 --- a/services/api/platform_auth.go +++ b/services/api/platform_auth.go @@ -12,6 +12,8 @@ import ( "strings" "sync" "time" + + "github.com/golang-jwt/jwt/v4" ) const platformAccessTokenTTL = 15 * time.Minute @@ -325,6 +327,9 @@ func (s *apiServer) handleOIDCLogin(w http.ResponseWriter, r *http.Request) { statusCode := http.StatusInternalServerError auditStatus := "error" auditResource := strings.ToLower(strings.TrimSpace(u.Email)) + if auditResource == "" { + auditResource = oidcAuditResource(idToken) + } if errors.Is(err, errOIDCUnauthorized) { statusCode = http.StatusUnauthorized auditStatus = "denied" @@ -354,7 +359,7 @@ func (s *apiServer) handleOIDCLogin(w http.ResponseWriter, r *http.Request) { } func (s *apiServer) resolveOIDCLoginUser(ctx context.Context, idToken string) (platformUser, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost/api/auth/me", nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://oidc.internal/verify", nil) if err != nil { return platformUser{}, err } @@ -364,36 +369,34 @@ func (s *apiServer) resolveOIDCLoginUser(ctx context.Context, idToken string) (p if err != nil { return platformUser{}, err } - if !ok || strings.TrimSpace(p.AuthType) != "oidc_jwt" { + if !ok || p.AuthType != "oidc_jwt" { return platformUser{}, fmt.Errorf("%w: token authentication failed", errOIDCUnauthorized) } - if strings.TrimSpace(p.Subject) == "" { - return platformUser{}, fmt.Errorf("%w: token missing user identity", errOIDCUnauthorized) - } - if strings.TrimSpace(p.Email) == "" { - return platformUser{}, fmt.Errorf("%w: token missing email", errOIDCUnauthorized) - } - role := strings.TrimSpace(p.Role) - if role == "" { - role = roleUser - } - - u, ok, err := s.platform.GetUser(ctx, p.Subject) - if err != nil { - return platformUser{}, err - } - if ok { - return u, nil + if p.Subject == "" || p.Email == "" { + return platformUser{}, fmt.Errorf("%w: token missing identity", errOIDCUnauthorized) } return platformUser{ - ID: strings.TrimSpace(p.Subject), - Email: strings.TrimSpace(p.Email), - Role: role, - Namespace: strings.TrimSpace(p.Namespace), + ID: p.Subject, + Email: p.Email, + Role: p.Role, + Namespace: p.Namespace, }, nil } +func oidcAuditResource(idToken string) string { + claims := jwt.MapClaims{} + if _, _, err := jwt.NewParser().ParseUnverified(strings.TrimSpace(idToken), claims); err != nil { + return "unknown" + } + email, _ := claims["email"].(string) + email = strings.ToLower(strings.TrimSpace(email)) + if email == "" { + return "unknown" + } + return email +} + func requestIP(r *http.Request) string { if xff := strings.TrimSpace(r.Header.Get("x-forwarded-for")); xff != "" { return strings.TrimSpace(strings.Split(xff, ",")[0]) diff --git a/services/api/platform_auth_oidc_test.go b/services/api/platform_auth_oidc_test.go index 09148b6..efa3398 100644 --- a/services/api/platform_auth_oidc_test.go +++ b/services/api/platform_auth_oidc_test.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -165,3 +166,28 @@ func TestHandleOIDCLoginInvalidOIDCToken(t *testing.T) { t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized) } } + +func TestOIDCAuditResourceUsesUnverifiedEmailClaim(t *testing.T) { + idToken := unsignedTestJWT(`{"email":"USER@example.COM"}`) + if got := oidcAuditResource(idToken); got != "user@example.com" { + t.Fatalf("oidcAuditResource() = %q, want user@example.com", got) + } +} + +func TestOIDCAuditResourceFallsBackToUnknown(t *testing.T) { + for _, idToken := range []string{ + "", + "not-a-jwt", + unsignedTestJWT(`{"sub":"user-123"}`), + } { + if got := oidcAuditResource(idToken); got != "unknown" { + t.Fatalf("oidcAuditResource(%q) = %q, want unknown", idToken, got) + } + } +} + +func unsignedTestJWT(payload string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none"}`)) + body := base64.RawURLEncoding.EncodeToString([]byte(payload)) + return header + "." + body + ".sig" +} diff --git a/services/ui/main.go b/services/ui/main.go index a8a4643..c241e30 100644 --- a/services/ui/main.go +++ b/services/ui/main.go @@ -372,7 +372,7 @@ func handleLogin(apiKey, upstreamAPIKey, apiUpstream string, store *uiSessionSto if oidcLoginHook != nil { p, token, expiresAt, verifyErr = oidcLoginHook(r.Context(), apiUpstream, idToken) } else { - p, token, expiresAt, verifyErr = loginOIDCWithAPI(r.Context(), apiUpstream, idToken) + p, token, expiresAt, verifyErr = loginOIDCSession(r.Context(), apiUpstream, idToken) } if verifyErr != nil { failures := loginAttempts.recordFailure(clientID) @@ -425,6 +425,22 @@ func handleLogin(apiKey, upstreamAPIKey, apiUpstream string, store *uiSessionSto } } +func loginOIDCSession(ctx context.Context, apiUpstream, idToken string) (sessionPrincipal, string, time.Time, error) { + p, token, expiresAt, err := loginOIDCWithAPI(ctx, apiUpstream, idToken) + if err == nil { + return p, token, expiresAt, nil + } + var statusErr *oidcLoginStatusError + if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusServiceUnavailable { + return sessionPrincipal{}, "", time.Time{}, err + } + p, verifyErr := verifyOIDCTokenWithAPI(ctx, apiUpstream, idToken) + if verifyErr != nil { + return sessionPrincipal{}, "", time.Time{}, verifyErr + } + return p, idToken, idTokenExpiry(idToken), nil +} + func loginPasswordWithAPI(ctx context.Context, apiUpstream, email, password string) (sessionPrincipal, string, error) { loginURL, err := apiUpstreamURL(apiUpstream, "api", "auth", "login") if err != nil { @@ -480,6 +496,35 @@ func loginPasswordWithAPI(ctx context.Context, apiUpstream, email, password stri }, payload.AccessToken, nil } +func idTokenExpiry(idToken string) time.Time { + parts := strings.Split(strings.TrimSpace(idToken), ".") + if len(parts) < 2 { + return time.Time{} + } + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return time.Time{} + } + var claims struct { + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return time.Time{} + } + if claims.Exp <= 0 { + return time.Time{} + } + return time.Unix(claims.Exp, 0).UTC() +} + +type oidcLoginStatusError struct { + StatusCode int +} + +func (e *oidcLoginStatusError) Error() string { + return fmt.Sprintf("oidc login failed: status %d", e.StatusCode) +} + func loginOIDCWithAPI(ctx context.Context, apiUpstream, idToken string) (sessionPrincipal, string, time.Time, error) { oidcURL, err := apiUpstreamURL(apiUpstream, "api", "auth", "oidc") if err != nil { @@ -506,7 +551,7 @@ func loginOIDCWithAPI(ctx context.Context, apiUpstream, idToken string) (session if resp.StatusCode != http.StatusOK { _, _ = io.Copy(io.Discard, resp.Body) - return sessionPrincipal{}, "", time.Time{}, fmt.Errorf("oidc login failed: status %d", resp.StatusCode) + return sessionPrincipal{}, "", time.Time{}, &oidcLoginStatusError{StatusCode: resp.StatusCode} } var payload struct { @@ -540,6 +585,50 @@ func loginOIDCWithAPI(ctx context.Context, apiUpstream, idToken string) (session }, strings.TrimSpace(payload.AccessToken), expiresAt, nil } +func verifyOIDCTokenWithAPI(ctx context.Context, apiUpstream, idToken string) (sessionPrincipal, error) { + meURL, err := apiUpstreamURL(apiUpstream, "api", "auth", "me") + if err != nil { + return sessionPrincipal{}, err + } + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, meURL, nil) + if err != nil { + return sessionPrincipal{}, err + } + req.Header.Set("authorization", "Bearer "+idToken) + + resp, err := authHTTPClient.Do(req) + if err != nil { + return sessionPrincipal{}, err + } + defer drainAndClose(resp.Body) + + if resp.StatusCode != http.StatusOK { + _, _ = io.Copy(io.Discard, resp.Body) + return sessionPrincipal{}, fmt.Errorf("auth check failed: status %d", resp.StatusCode) + } + + var payload struct { + Authenticated bool `json:"authenticated"` + Principal sessionPrincipal `json:"principal"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return sessionPrincipal{}, err + } + if !payload.Authenticated { + return sessionPrincipal{}, errors.New("not authenticated") + } + if strings.TrimSpace(payload.Principal.Role) == "" { + payload.Principal.Role = "user" + } + if payload.Principal.AuthType == "" { + payload.Principal.AuthType = "oidc_jwt" + } + return payload.Principal, nil +} + func apiUpstreamURL(apiUpstream string, parts ...string) (string, error) { base := strings.TrimSpace(apiUpstream) if base == "" { diff --git a/services/ui/main_test.go b/services/ui/main_test.go index c2c8875..18866bb 100644 --- a/services/ui/main_test.go +++ b/services/ui/main_test.go @@ -2,6 +2,8 @@ package main import ( "context" + "encoding/base64" + "fmt" "io" "net/http" "net/http/httptest" @@ -264,6 +266,67 @@ func TestHandleLoginWithOIDCTokenCapsSessionToTokenExpiry(t *testing.T) { } } +func TestLoginOIDCSessionFallsBackToTokenVerificationWhenPlatformStoreUnavailable(t *testing.T) { + now := time.Now().UTC().Add(2 * time.Minute) + exp := now.Add(30 * time.Minute) + payload := fmt.Sprintf(`{"exp":%d}`, exp.Unix()) + idToken := "eyJhbGciOiJub25lIn0." + base64.RawURLEncoding.EncodeToString([]byte(payload)) + ".sig" + + var paths []string + previousClient := authHTTPClient + authHTTPClient = &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + paths = append(paths, r.URL.Path) + switch r.URL.Path { + case "/api/auth/oidc": + if r.Method != http.MethodPost { + t.Fatalf("oidc method = %s, want POST", r.Method) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read oidc body: %v", err) + } + if !strings.Contains(string(body), idToken) { + t.Fatalf("oidc body = %s, want id token", string(body)) + } + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{"content-type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"error":"platform identity database not configured"}`)), + }, nil + case "/api/auth/me": + if got := r.Header.Get("authorization"); got != "Bearer "+idToken { + t.Fatalf("fallback authorization = %q, want bearer id token", got) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"content-type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"authenticated":true,"principal":{"role":"user","subject":"user-123","email":"user@example.com"}}`)), + }, nil + default: + t.Fatalf("unexpected request path %q", r.URL.Path) + } + return nil, nil + })} + t.Cleanup(func() { authHTTPClient = previousClient }) + + p, token, expiresAt, err := loginOIDCSession(context.Background(), "http://api.example", idToken) + if err != nil { + t.Fatalf("loginOIDCSession() error = %v", err) + } + if token != idToken { + t.Fatalf("token = %q, want original id token", token) + } + if p.AuthType != "oidc_jwt" || p.Subject != "user-123" || p.Email != "user@example.com" { + t.Fatalf("principal = %+v", p) + } + if expiresAt.After(exp.Add(time.Second)) || expiresAt.Before(exp.Add(-time.Second)) { + t.Fatalf("session expiry = %s, want %s", expiresAt.Format(time.RFC3339), exp.Format(time.RFC3339)) + } + if len(paths) != 2 || paths[0] != "/api/auth/oidc" || paths[1] != "/api/auth/me" { + t.Fatalf("request paths = %v, want oidc exchange then auth/me fallback", paths) + } +} + func TestUISessionStateIsEphemeralAcrossStoreRestart(t *testing.T) { now := time.Now().UTC() previousHook := oidcLoginHook