1717package main
1818
1919import (
20+ "encoding/base64"
2021 "encoding/json"
2122 "fmt"
2223 "io/ioutil"
@@ -58,10 +59,24 @@ func maxVersion(version1 string, version2 string) string {
5859 return version2
5960}
6061
62+ func isProtectedProp (prop string ) bool {
63+ protectedProps := []string {
64+ "sub" ,
65+ "iat" ,
66+ "exp" ,
67+ "sessionHandle" ,
68+ "parentRefreshTokenHash1" ,
69+ "refreshTokenHash1" ,
70+ "antiCsrfToken" ,
71+ }
72+
73+ return supertokens .DoesSliceContainString (prop , protectedProps )
74+ }
75+
6176var routes * http.Handler
6277
6378func callSTInit (enableAntiCsrf bool , enableJWT bool , jwtPropertyName string ) {
64- if maxVersion (supertokens .VERSION , "0.3.1 " ) == supertokens .VERSION && enableJWT {
79+ if maxVersion (supertokens .VERSION , "0.12.0 " ) == supertokens .VERSION && enableJWT {
6580 port := "8080"
6681 if len (os .Args ) == 2 {
6782 port = os .Args [1 ]
@@ -81,6 +96,7 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
8196 },
8297 RecipeList : []supertokens.Recipe {
8398 session .Init (& sessmodels.TypeInput {
99+ ExposeAccessTokenToFrontendInCookieBasedAuth : true ,
84100 ErrorHandlers : & sessmodels.ErrorHandlers {
85101 OnUnauthorised : func (message string , req * http.Request , res http.ResponseWriter ) error {
86102 res .Header ().Set ("Content-Type" , "text/html; charset=utf-8" )
@@ -93,7 +109,61 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
93109 Override : & sessmodels.OverrideStruct {
94110 Functions : func (originalImplementation sessmodels.RecipeInterface ) sessmodels.RecipeInterface {
95111 ogCNS := * originalImplementation .CreateNewSession
96- (* originalImplementation .CreateNewSession ) = func (userID string , accessTokenPayload map [string ]interface {}, sessionDataInDatabase map [string ]interface {}, disableAntiCsrf * bool , userContext supertokens.UserContext ) (sessmodels.SessionContainer , error ) {
112+ * originalImplementation .CreateNewSession = func (userID string , accessTokenPayload map [string ]interface {}, sessionDataInDatabase map [string ]interface {}, disableAntiCsrf * bool , userContext supertokens.UserContext ) (sessmodels.SessionContainer , error ) {
113+ if accessTokenPayload == nil {
114+ accessTokenPayload = map [string ]interface {}{}
115+ }
116+ accessTokenPayload ["customClaim" ] = "customValue"
117+
118+ return ogCNS (userID , accessTokenPayload , sessionDataInDatabase , disableAntiCsrf , userContext )
119+ }
120+ return originalImplementation
121+ },
122+ APIs : func (originalImplementation sessmodels.APIInterface ) sessmodels.APIInterface {
123+ originalImplementation .RefreshPOST = nil
124+ return originalImplementation
125+ },
126+ },
127+ }),
128+ },
129+ })
130+
131+ if err != nil {
132+ panic (err .Error ())
133+ }
134+ } else if maxVersion (supertokens .VERSION , "0.3.1" ) == supertokens .VERSION && enableJWT {
135+ port := "8080"
136+ if len (os .Args ) == 2 {
137+ port = os .Args [1 ]
138+ }
139+ antiCsrf := "NONE"
140+ if enableAntiCsrf {
141+ antiCsrf = "VIA_TOKEN"
142+ }
143+ err := supertokens .Init (supertokens.TypeInput {
144+ Supertokens : & supertokens.ConnectionInfo {
145+ ConnectionURI : "http://localhost:9000" ,
146+ },
147+ AppInfo : supertokens.AppInfo {
148+ AppName : "SuperTokens" ,
149+ APIDomain : "0.0.0.0:" + port ,
150+ WebsiteDomain : "http://localhost.org:8080" ,
151+ },
152+ RecipeList : []supertokens.Recipe {
153+ session .Init (& sessmodels.TypeInput {
154+ ErrorHandlers : & sessmodels.ErrorHandlers {
155+ OnUnauthorised : func (message string , req * http.Request , res http.ResponseWriter ) error {
156+ res .Header ().Set ("Content-Type" , "text/html; charset=utf-8" )
157+ res .WriteHeader (401 )
158+ res .Write ([]byte ("" ))
159+ return nil
160+ },
161+ },
162+ AntiCsrf : & antiCsrf ,
163+ Override : & sessmodels.OverrideStruct {
164+ Functions : func (originalImplementation sessmodels.RecipeInterface ) sessmodels.RecipeInterface {
165+ ogCNS := * originalImplementation .CreateNewSession
166+ * originalImplementation .CreateNewSession = func (userID string , accessTokenPayload map [string ]interface {}, sessionDataInDatabase map [string ]interface {}, disableAntiCsrf * bool , userContext supertokens.UserContext ) (sessmodels.SessionContainer , error ) {
97167 if accessTokenPayload == nil {
98168 accessTokenPayload = map [string ]interface {}{}
99169 }
@@ -166,6 +236,8 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
166236 setEnableJWT (rw , r )
167237 } else if r .URL .Path == "/login" && r .Method == "POST" {
168238 login (rw , r )
239+ } else if r .URL .Path == "/login-2.18" && r .Method == "POST" {
240+ login218 (rw , r )
169241 } else if r .URL .Path == "/beforeeach" && r .Method == "POST" {
170242 beforeeach (rw , r )
171243 } else if r .URL .Path == "/testUserConfig" && r .Method == "POST" {
@@ -273,7 +345,8 @@ func reinitialiseBackendConfig(w http.ResponseWriter, r *http.Request) {
273345
274346func featureFlag (response http.ResponseWriter , request * http.Request ) {
275347 json .NewEncoder (response ).Encode (map [string ]interface {}{
276- "sessionJwt" : maxVersion (supertokens .VERSION , "0.3.1" ) == supertokens .VERSION && lastEnableJWTSetting ,
348+ "sessionJwt" : maxVersion (supertokens .VERSION , "0.3.1" ) == supertokens .VERSION && lastEnableJWTSetting ,
349+ "v3AccessToken" : maxVersion (supertokens .VERSION , "0.12.0" ) == supertokens .VERSION ,
277350 })
278351}
279352
@@ -348,15 +421,48 @@ func updateJwt(response http.ResponseWriter, request *http.Request) {
348421 var body map [string ]interface {}
349422 _ = json .NewDecoder (request .Body ).Decode (& body )
350423 userSession := session .GetSessionFromRequestContext (request .Context ())
351- userSession .MergeIntoAccessTokenPayload (body )
424+
425+ sessionAccessTokenPayload := userSession .GetAccessTokenPayload ()
426+ clearing := map [string ]interface {}{}
427+
428+ for k := range sessionAccessTokenPayload {
429+ if ! isProtectedProp (k ) {
430+ clearing [k ] = nil
431+ }
432+ }
433+
434+ for k , v := range body {
435+ clearing [k ] = v
436+ }
437+
438+ userSession .MergeIntoAccessTokenPayload (clearing )
352439 json .NewEncoder (response ).Encode (userSession .GetAccessTokenPayload ())
353440}
354441
355442func updateJwtWithHandle (response http.ResponseWriter , request * http.Request ) {
356443 var body map [string ]interface {}
357444 _ = json .NewDecoder (request .Body ).Decode (& body )
358445 userSession := session .GetSessionFromRequestContext (request .Context ())
359- session .MergeIntoAccessTokenPayload (userSession .GetHandle (), body )
446+ sessionInformation , err := session .GetSessionInformation (userSession .GetHandle ())
447+
448+ if err != nil {
449+ response .WriteHeader (500 )
450+ response .Write ([]byte ("" ))
451+ return
452+ }
453+
454+ customClaimsInPayload := sessionInformation .CustomClaimsInAccessTokenPayload
455+ clearing := map [string ]interface {}{}
456+
457+ for k := range customClaimsInPayload {
458+ clearing [k ] = nil
459+ }
460+
461+ for k , v := range body {
462+ clearing [k ] = v
463+ }
464+
465+ session .MergeIntoAccessTokenPayload (userSession .GetHandle (), clearing )
360466 json .NewEncoder (response ).Encode (userSession .GetAccessTokenPayload ())
361467}
362468
@@ -407,6 +513,71 @@ func login(response http.ResponseWriter, request *http.Request) {
407513 response .Write ([]byte (sess .GetUserID ()))
408514}
409515
516+ func login218 (response http.ResponseWriter , request * http.Request ) {
517+ var body map [string ]interface {}
518+ _ = json .NewDecoder (request .Body ).Decode (& body )
519+
520+ userID := body ["userId" ].(string )
521+ payload := body ["payload" ].(map [string ]interface {})
522+
523+ querier , err := supertokens .GetNewQuerierInstanceOrThrowError ("session" )
524+
525+ if err != nil {
526+ response .WriteHeader (500 )
527+ response .Write ([]byte ("" ))
528+ return
529+ }
530+
531+ querier .SetApiVersionForTests ("2.18" )
532+ resp , err := querier .SendPostRequest ("/recipe/session" , map [string ]interface {}{
533+ "userId" : userID ,
534+ "userDataInJWT" : payload ,
535+ "userDataInDatabase" : map [string ]interface {}{},
536+ "enableAntiCsrf" : false ,
537+ })
538+
539+ if err != nil {
540+ response .WriteHeader (500 )
541+ response .Write ([]byte ("" ))
542+ return
543+ }
544+
545+ querier .SetApiVersionForTests ("" )
546+
547+ responseByte , err := json .Marshal (resp )
548+ if err != nil {
549+ response .WriteHeader (500 )
550+ response .Write ([]byte ("" ))
551+ return
552+ }
553+ var sessionResp sessmodels.CreateOrRefreshAPIResponse
554+ err = json .Unmarshal (responseByte , & sessionResp )
555+ if err != nil {
556+ response .WriteHeader (500 )
557+ response .Write ([]byte ("" ))
558+ return
559+ }
560+
561+ legacyAccessToken := sessionResp .AccessToken .Token
562+ legacyRefreshToken := sessionResp .RefreshToken .Token
563+
564+ frontTokenJson := json .NewEncoder (response ).Encode (map [string ]interface {}{
565+ "uid" : userID ,
566+ "ate" : session .GetCurrTimeInMS () + 3600000 ,
567+ "up" : payload ,
568+ })
569+
570+ parsed , _ := json .Marshal (frontTokenJson )
571+ data := []byte (parsed )
572+
573+ frontToken := base64 .StdEncoding .EncodeToString (data )
574+
575+ response .Header ().Set ("st-access-token" , legacyAccessToken )
576+ response .Header ().Set ("st-refresh-token" , legacyRefreshToken )
577+ response .Header ().Set ("front-token" , frontToken )
578+ response .Write ([]byte ("" ))
579+ }
580+
410581func fail (w http.ResponseWriter , r * http.Request ) {
411582 w .WriteHeader (404 )
412583 w .Write ([]byte ("" ))
0 commit comments