Skip to content

Commit 538437c

Browse files
Merge pull request #280 from supertokens/jwt-rework/fix-tests
Update integration server and fix for tests
2 parents 2b05770 + 75b02b9 commit 538437c

File tree

2 files changed

+177
-6
lines changed

2 files changed

+177
-6
lines changed

recipe/session/recipeImplementation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
176176
}
177177

178178
session := response.Session
179-
frontToken := BuildFrontToken(session.UserID, session.ExpiryTime, responseToken.Payload)
179+
frontToken := BuildFrontToken(session.UserID, response.AccessToken.Expiry, responseToken.Payload)
180180

181181
sessionContainerInput := makeSessionContainerInput(response.AccessToken.Token, session.Handle, session.UserID, responseToken.Payload, result, frontToken, response.AntiCsrfToken, nil, &response.RefreshToken, true)
182182
sessionContainer := newSessionContainer(config, &sessionContainerInput)

test/frontendIntegration/main.go

Lines changed: 176 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package main
1818

1919
import (
20+
"encoding/base64"
2021
"encoding/json"
2122
"fmt"
2223
"io/ioutil"
@@ -58,10 +59,24 @@ func maxVersion(version1 string, version2 string) string {
5859
return version2
5960
}
6061

62+
func isProtectedProp(prop string) bool {
63+
protectedProps := []string{
64+
"sub",
65+
"iat",
66+
"exp",
67+
"sessionHandle",
68+
"parentRefreshTokenHash1",
69+
"refreshTokenHash1",
70+
"antiCsrfToken",
71+
}
72+
73+
return supertokens.DoesSliceContainString(prop, protectedProps)
74+
}
75+
6176
var routes *http.Handler
6277

6378
func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
64-
if maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && enableJWT {
79+
if maxVersion(supertokens.VERSION, "0.12.0") == supertokens.VERSION && enableJWT {
6580
port := "8080"
6681
if len(os.Args) == 2 {
6782
port = os.Args[1]
@@ -81,6 +96,7 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
8196
},
8297
RecipeList: []supertokens.Recipe{
8398
session.Init(&sessmodels.TypeInput{
99+
ExposeAccessTokenToFrontendInCookieBasedAuth: true,
84100
ErrorHandlers: &sessmodels.ErrorHandlers{
85101
OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error {
86102
res.Header().Set("Content-Type", "text/html; charset=utf-8")
@@ -93,7 +109,61 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
93109
Override: &sessmodels.OverrideStruct{
94110
Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface {
95111
ogCNS := *originalImplementation.CreateNewSession
96-
(*originalImplementation.CreateNewSession) = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
112+
*originalImplementation.CreateNewSession = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
113+
if accessTokenPayload == nil {
114+
accessTokenPayload = map[string]interface{}{}
115+
}
116+
accessTokenPayload["customClaim"] = "customValue"
117+
118+
return ogCNS(userID, accessTokenPayload, sessionDataInDatabase, disableAntiCsrf, userContext)
119+
}
120+
return originalImplementation
121+
},
122+
APIs: func(originalImplementation sessmodels.APIInterface) sessmodels.APIInterface {
123+
originalImplementation.RefreshPOST = nil
124+
return originalImplementation
125+
},
126+
},
127+
}),
128+
},
129+
})
130+
131+
if err != nil {
132+
panic(err.Error())
133+
}
134+
} else if maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && enableJWT {
135+
port := "8080"
136+
if len(os.Args) == 2 {
137+
port = os.Args[1]
138+
}
139+
antiCsrf := "NONE"
140+
if enableAntiCsrf {
141+
antiCsrf = "VIA_TOKEN"
142+
}
143+
err := supertokens.Init(supertokens.TypeInput{
144+
Supertokens: &supertokens.ConnectionInfo{
145+
ConnectionURI: "http://localhost:9000",
146+
},
147+
AppInfo: supertokens.AppInfo{
148+
AppName: "SuperTokens",
149+
APIDomain: "0.0.0.0:" + port,
150+
WebsiteDomain: "http://localhost.org:8080",
151+
},
152+
RecipeList: []supertokens.Recipe{
153+
session.Init(&sessmodels.TypeInput{
154+
ErrorHandlers: &sessmodels.ErrorHandlers{
155+
OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error {
156+
res.Header().Set("Content-Type", "text/html; charset=utf-8")
157+
res.WriteHeader(401)
158+
res.Write([]byte(""))
159+
return nil
160+
},
161+
},
162+
AntiCsrf: &antiCsrf,
163+
Override: &sessmodels.OverrideStruct{
164+
Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface {
165+
ogCNS := *originalImplementation.CreateNewSession
166+
*originalImplementation.CreateNewSession = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
97167
if accessTokenPayload == nil {
98168
accessTokenPayload = map[string]interface{}{}
99169
}
@@ -166,6 +236,8 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
166236
setEnableJWT(rw, r)
167237
} else if r.URL.Path == "/login" && r.Method == "POST" {
168238
login(rw, r)
239+
} else if r.URL.Path == "/login-2.18" && r.Method == "POST" {
240+
login218(rw, r)
169241
} else if r.URL.Path == "/beforeeach" && r.Method == "POST" {
170242
beforeeach(rw, r)
171243
} else if r.URL.Path == "/testUserConfig" && r.Method == "POST" {
@@ -273,7 +345,8 @@ func reinitialiseBackendConfig(w http.ResponseWriter, r *http.Request) {
273345

274346
func featureFlag(response http.ResponseWriter, request *http.Request) {
275347
json.NewEncoder(response).Encode(map[string]interface{}{
276-
"sessionJwt": maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && lastEnableJWTSetting,
348+
"sessionJwt": maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && lastEnableJWTSetting,
349+
"v3AccessToken": maxVersion(supertokens.VERSION, "0.12.0") == supertokens.VERSION,
277350
})
278351
}
279352

@@ -348,15 +421,48 @@ func updateJwt(response http.ResponseWriter, request *http.Request) {
348421
var body map[string]interface{}
349422
_ = json.NewDecoder(request.Body).Decode(&body)
350423
userSession := session.GetSessionFromRequestContext(request.Context())
351-
userSession.MergeIntoAccessTokenPayload(body)
424+
425+
sessionAccessTokenPayload := userSession.GetAccessTokenPayload()
426+
clearing := map[string]interface{}{}
427+
428+
for k := range sessionAccessTokenPayload {
429+
if !isProtectedProp(k) {
430+
clearing[k] = nil
431+
}
432+
}
433+
434+
for k, v := range body {
435+
clearing[k] = v
436+
}
437+
438+
userSession.MergeIntoAccessTokenPayload(clearing)
352439
json.NewEncoder(response).Encode(userSession.GetAccessTokenPayload())
353440
}
354441

355442
func updateJwtWithHandle(response http.ResponseWriter, request *http.Request) {
356443
var body map[string]interface{}
357444
_ = json.NewDecoder(request.Body).Decode(&body)
358445
userSession := session.GetSessionFromRequestContext(request.Context())
359-
session.MergeIntoAccessTokenPayload(userSession.GetHandle(), body)
446+
sessionInformation, err := session.GetSessionInformation(userSession.GetHandle())
447+
448+
if err != nil {
449+
response.WriteHeader(500)
450+
response.Write([]byte(""))
451+
return
452+
}
453+
454+
customClaimsInPayload := sessionInformation.CustomClaimsInAccessTokenPayload
455+
clearing := map[string]interface{}{}
456+
457+
for k := range customClaimsInPayload {
458+
clearing[k] = nil
459+
}
460+
461+
for k, v := range body {
462+
clearing[k] = v
463+
}
464+
465+
session.MergeIntoAccessTokenPayload(userSession.GetHandle(), clearing)
360466
json.NewEncoder(response).Encode(userSession.GetAccessTokenPayload())
361467
}
362468

@@ -407,6 +513,71 @@ func login(response http.ResponseWriter, request *http.Request) {
407513
response.Write([]byte(sess.GetUserID()))
408514
}
409515

516+
func login218(response http.ResponseWriter, request *http.Request) {
517+
var body map[string]interface{}
518+
_ = json.NewDecoder(request.Body).Decode(&body)
519+
520+
userID := body["userId"].(string)
521+
payload := body["payload"].(map[string]interface{})
522+
523+
querier, err := supertokens.GetNewQuerierInstanceOrThrowError("session")
524+
525+
if err != nil {
526+
response.WriteHeader(500)
527+
response.Write([]byte(""))
528+
return
529+
}
530+
531+
querier.SetApiVersionForTests("2.18")
532+
resp, err := querier.SendPostRequest("/recipe/session", map[string]interface{}{
533+
"userId": userID,
534+
"userDataInJWT": payload,
535+
"userDataInDatabase": map[string]interface{}{},
536+
"enableAntiCsrf": false,
537+
})
538+
539+
if err != nil {
540+
response.WriteHeader(500)
541+
response.Write([]byte(""))
542+
return
543+
}
544+
545+
querier.SetApiVersionForTests("")
546+
547+
responseByte, err := json.Marshal(resp)
548+
if err != nil {
549+
response.WriteHeader(500)
550+
response.Write([]byte(""))
551+
return
552+
}
553+
var sessionResp sessmodels.CreateOrRefreshAPIResponse
554+
err = json.Unmarshal(responseByte, &sessionResp)
555+
if err != nil {
556+
response.WriteHeader(500)
557+
response.Write([]byte(""))
558+
return
559+
}
560+
561+
legacyAccessToken := sessionResp.AccessToken.Token
562+
legacyRefreshToken := sessionResp.RefreshToken.Token
563+
564+
frontTokenJson := json.NewEncoder(response).Encode(map[string]interface{}{
565+
"uid": userID,
566+
"ate": session.GetCurrTimeInMS() + 3600000,
567+
"up": payload,
568+
})
569+
570+
parsed, _ := json.Marshal(frontTokenJson)
571+
data := []byte(parsed)
572+
573+
frontToken := base64.StdEncoding.EncodeToString(data)
574+
575+
response.Header().Set("st-access-token", legacyAccessToken)
576+
response.Header().Set("st-refresh-token", legacyRefreshToken)
577+
response.Header().Set("front-token", frontToken)
578+
response.Write([]byte(""))
579+
}
580+
410581
func fail(w http.ResponseWriter, r *http.Request) {
411582
w.WriteHeader(404)
412583
w.Write([]byte(""))

0 commit comments

Comments
 (0)