Skip to content

Commit 5f5d148

Browse files
Merge pull request #293 from supertokens/flow-fixes
refactor: Refactor JWKs fetching logic for session recipe
2 parents 485887c + a4a14cc commit 5f5d148

File tree

4 files changed

+476
-42
lines changed

4 files changed

+476
-42
lines changed

recipe/session/recipeImplementation.go

Lines changed: 81 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
2727
"github.com/supertokens/supertokens-golang/supertokens"
2828
"reflect"
29+
"sync"
2930
"time"
3031
)
3132

@@ -39,31 +40,90 @@ var protectedProps = []string{
3940
"antiCsrfToken",
4041
}
4142

42-
var JWKCacheMaxAgeInMs = 60000
43+
var JWKCacheMaxAgeInMs int64 = 60000
4344
var JWKRefreshRateLimit = 500
44-
var jwksResults []sessmodels.GetJWKSResult
45+
var jwksCache *sessmodels.GetJWKSResult = nil
46+
var mutex sync.RWMutex
47+
48+
func getJWKSFromCacheIfPresent() *sessmodels.GetJWKSResult {
49+
mutex.RLock()
50+
defer mutex.RUnlock()
51+
if jwksCache != nil {
52+
// This means that we have valid JWKs for the given core path
53+
// We check if we need to refresh before returning
54+
currentTime := time.Now().UnixNano() / int64(time.Millisecond)
55+
56+
// This means that the value in cache is not expired, in this case we return the cached value
57+
//
58+
// Note that this also means that the SDK will not try to query any other Core (if there are multiple)
59+
// if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
60+
// from the cores again after the entry in the cache is expired
61+
if (currentTime - jwksCache.LastFetched) < JWKCacheMaxAgeInMs {
62+
if supertokens.IsRunningInTestMode() {
63+
returnedFromCache = true
64+
}
4565

46-
func getJWKS() []sessmodels.GetJWKSResult {
47-
result := []sessmodels.GetJWKSResult{}
66+
return jwksCache
67+
}
68+
}
69+
70+
return nil
71+
}
72+
73+
func getJWKS() (*keyfunc.JWKS, error) {
4874
corePaths := supertokens.GetAllCoreUrlsForPath("/.well-known/jwks.json")
4975

76+
if len(corePaths) == 0 {
77+
return nil, defaultErrors.New("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using.")
78+
}
79+
80+
resultFromCache := getJWKSFromCacheIfPresent()
81+
82+
if resultFromCache != nil {
83+
return resultFromCache.JWKS, nil
84+
}
85+
86+
var lastError error
87+
88+
mutex.Lock()
89+
defer mutex.Unlock()
5090
for _, path := range corePaths {
51-
// RefreshUnknownKID - Fetch JWKS again if the kid in the header of the JWT does not match any in cache
52-
// RefreshRateLimit - Only allow one re-fetch every 500 milliseconds
53-
// RefreshInterval - Refreshes should occur every 600 seconds
54-
jwks, err := keyfunc.Get(path, keyfunc.Options{
91+
if supertokens.IsRunningInTestMode() {
92+
urlsAttemptedForJWKSFetch = append(urlsAttemptedForJWKSFetch, path)
93+
}
94+
95+
// RefreshUnknownKID - Fetch JWKS again if the kid in the header of the JWT does not match any in
96+
// the keyfunc library's cache
97+
jwks, jwksError := keyfunc.Get(path, keyfunc.Options{
5598
RefreshUnknownKID: true,
56-
RefreshRateLimit: time.Millisecond * time.Duration(JWKRefreshRateLimit),
57-
RefreshInterval: time.Millisecond * time.Duration(JWKCacheMaxAgeInMs),
5899
})
59100

60-
result = append(result, sessmodels.GetJWKSResult{
61-
JWKS: jwks,
62-
Error: err,
63-
})
101+
if jwksError == nil {
102+
jwksResult := sessmodels.GetJWKSResult{
103+
JWKS: jwks,
104+
Error: jwksError,
105+
LastFetched: time.Now().UnixNano() / int64(time.Millisecond),
106+
}
107+
108+
// Dont add to cache if there is an error to keep the logic of checking cache simple
109+
//
110+
// This also has the added benefit where if initially the request failed because the core
111+
// was down and then it comes back up, the next time it will try to request that core again
112+
// after the cache has expired
113+
jwksCache = &jwksResult
114+
115+
if supertokens.IsRunningInTestMode() {
116+
returnedFromCache = false
117+
}
118+
119+
return jwksResult.JWKS, nil
120+
}
121+
122+
lastError = jwksError
64123
}
65124

66-
return result
125+
// This means that fetching from all cores failed
126+
return nil, lastError
67127
}
68128

69129
/**
@@ -74,29 +134,21 @@ Every core instance a backend is connected to is expected to connect to the same
74134
token verification. Otherwise, the result of session verification would depend on which core is currently available.
75135
*/
76136
func GetCombinedJWKS() (*keyfunc.JWKS, error) {
77-
var lastError error
78-
79-
if len(jwksResults) == 0 {
80-
return nil, defaultErrors.New("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using.")
137+
if supertokens.IsRunningInTestMode() {
138+
urlsAttemptedForJWKSFetch = []string{}
81139
}
82140

83-
for _, jwk := range jwksResults {
84-
jwksResult := jwk.JWKS
85-
err := jwk.Error
141+
jwksResult, err := getJWKS()
86142

87-
if err != nil {
88-
lastError = err
89-
} else {
90-
return jwksResult, nil
91-
}
143+
if err != nil {
144+
return nil, err
92145
}
93146

94-
return nil, lastError
147+
return jwksResult, nil
95148
}
96149

97150
func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) sessmodels.RecipeInterface {
98151
var result sessmodels.RecipeInterface
99-
jwksResults = getJWKS()
100152

101153
createNewSession := func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
102154
supertokens.LogDebugMessage("createNewSession: Started")

0 commit comments

Comments
 (0)