11package callback
22
33import (
4- "encoding/base64"
5- "encoding/json"
64 "fmt"
75 "log"
86 "net/http"
97 "net/url"
10- "os"
118 "strings"
129 "time"
1310
@@ -20,22 +17,48 @@ import (
2017type Store interface {
2118 StoreGrant (grant * types.Grant ) error
2219 StoreAuthCode (code , grantID , userID string ) error
20+ GetAuthRequest (key string ) (map [string ]interface {}, error )
21+ DeleteAuthRequest (key string ) error
2322}
2423
2524type Handler struct {
2625 db Store
2726 provider providers.Provider
2827 encryptionKey []byte
28+ clientID string
29+ clientSecret string
2930}
3031
31- func NewHandler (db Store , provider providers.Provider , encryptionKey []byte ) http.Handler {
32+ func NewHandler (db Store , provider providers.Provider , encryptionKey []byte , clientID , clientSecret string ) http.Handler {
3233 return & Handler {
3334 db : db ,
3435 provider : provider ,
3536 encryptionKey : encryptionKey ,
37+ clientID : clientID ,
38+ clientSecret : clientSecret ,
3639 }
3740}
3841
42+ // scopeContainsProfileOrEmail checks if the given scopes contain profile or email
43+ func (p * Handler ) scopeContainsProfileOrEmail (scopes []string ) bool {
44+ for _ , scope := range scopes {
45+ if scope == "profile" || scope == "email" {
46+ return true
47+ }
48+ }
49+ return false
50+ }
51+
52+ // getStringFromMap safely extracts a string value from a map[string]interface{}
53+ func getStringFromMap (data map [string ]interface {}, key string ) string {
54+ if value , ok := data [key ]; ok {
55+ if str , ok := value .(string ); ok {
56+ return str
57+ }
58+ }
59+ return ""
60+ }
61+
3962func (p * Handler ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
4063 // Handle OAuth callback from external providers
4164 code := r .URL .Query ().Get ("code" )
@@ -61,30 +84,39 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6184 return
6285 }
6386
64- var authReq types. AuthRequest
65- stateData , err := base64 . URLEncoding . DecodeString (state )
87+ // Retrieve auth request data from database using state as key
88+ authData , err := p . db . GetAuthRequest (state )
6689 if err != nil {
6790 handlerutils .JSON (w , http .StatusBadRequest , types.OAuthError {
6891 Error : "invalid_request" ,
69- ErrorDescription : "Invalid state parameter" ,
92+ ErrorDescription : "Invalid or expired state parameter" ,
7093 })
7194 return
7295 }
73- if err := json .Unmarshal (stateData , & authReq ); err != nil {
74- handlerutils .JSON (w , http .StatusBadRequest , types.OAuthError {
75- Error : "invalid_request" ,
76- ErrorDescription : "Invalid state parameter" ,
77- })
78- return
96+
97+ // Convert auth data back to AuthRequest struct
98+ authReq := types.AuthRequest {
99+ ResponseType : getStringFromMap (authData , "response_type" ),
100+ ClientID : getStringFromMap (authData , "client_id" ),
101+ RedirectURI : getStringFromMap (authData , "redirect_uri" ),
102+ Scope : getStringFromMap (authData , "scope" ),
103+ State : getStringFromMap (authData , "state" ),
104+ CodeChallenge : getStringFromMap (authData , "code_challenge" ),
105+ CodeChallengeMethod : getStringFromMap (authData , "code_challenge_method" ),
79106 }
80107
108+ // Clean up the auth request data after successful retrieval
109+ defer func () {
110+ if err := p .db .DeleteAuthRequest (state ); err != nil {
111+ log .Printf ("Failed to delete auth request: %v" , err )
112+ }
113+ }()
114+
81115 // Get provider credentials
82- clientID := os .Getenv ("OAUTH_CLIENT_ID" )
83- clientSecret := os .Getenv ("OAUTH_CLIENT_SECRET" )
84116 redirectURI := fmt .Sprintf ("%s/callback" , handlerutils .GetBaseURL (r ))
85117
86118 // Exchange code for tokens
87- tokenInfo , err := p .provider .ExchangeCodeForToken (r .Context (), code , clientID , clientSecret , redirectURI )
119+ tokenInfo , err := p .provider .ExchangeCodeForToken (r .Context (), code , p . clientID , p . clientSecret , redirectURI )
88120 if err != nil {
89121 log .Printf ("Failed to exchange code for token: %v" , err )
90122 handlerutils .JSON (w , http .StatusBadRequest , types.OAuthError {
@@ -94,15 +126,22 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
94126 return
95127 }
96128
97- // Get user info from the provider
98- userInfo , err := p .provider .GetUserInfo (r .Context (), tokenInfo .AccessToken )
99- if err != nil {
100- log .Printf ("Failed to get user info: %v" , err )
101- handlerutils .JSON (w , http .StatusBadRequest , types.OAuthError {
102- Error : "invalid_grant" ,
103- ErrorDescription : "Failed to get user information" ,
104- })
105- return
129+ // Check if scope includes profile or email before getting user info
130+ scopes := strings .Fields (authReq .Scope )
131+ needsUserInfo := p .scopeContainsProfileOrEmail (scopes )
132+
133+ userInfo := & providers.UserInfo {}
134+ if needsUserInfo {
135+ // Get user info from the provider
136+ userInfo , err = p .provider .GetUserInfo (r .Context (), tokenInfo .AccessToken )
137+ if err != nil {
138+ log .Printf ("Failed to get user info: %v" , err )
139+ handlerutils .JSON (w , http .StatusBadRequest , types.OAuthError {
140+ Error : "invalid_grant" ,
141+ ErrorDescription : "Failed to get user information" ,
142+ })
143+ return
144+ }
106145 }
107146
108147 // Create a grant for this user
@@ -111,13 +150,17 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
111150
112151 // Prepare sensitive props data
113152 sensitiveProps := map [string ]interface {}{
114- "email" : userInfo .Email ,
115- "name" : userInfo .Name ,
116153 "access_token" : tokenInfo .AccessToken ,
117154 "refresh_token" : tokenInfo .RefreshToken ,
118155 "expires_at" : tokenInfo .ExpireAt ,
119156 }
120157
158+ // Only add user info if we have it
159+ if needsUserInfo {
160+ sensitiveProps ["email" ] = userInfo .Email
161+ sensitiveProps ["name" ] = userInfo .Name
162+ }
163+
121164 // Initialize props map
122165 props := make (map [string ]interface {})
123166
@@ -138,17 +181,13 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
138181 props ["algorithm" ] = encryptedProps .Algorithm
139182 props ["encrypted" ] = true
140183
141- // Add non-sensitive data
142- props ["user_id" ] = userInfo .ID
143-
144184 grant := & types.Grant {
145185 ID : grantID ,
146186 ClientID : authReq .ClientID ,
147187 UserID : userInfo .ID ,
148- Scope : strings . Fields ( authReq . Scope ) ,
188+ Scope : scopes ,
149189 Metadata : map [string ]interface {}{
150190 "provider" : p .provider ,
151- "label" : userInfo .Name ,
152191 },
153192 Props : props ,
154193 CreatedAt : now ,
0 commit comments