Skip to content

Commit 7e38104

Browse files
committed
Update logic as per PR comments
1 parent 9577b9b commit 7e38104

File tree

2 files changed

+59
-77
lines changed

2 files changed

+59
-77
lines changed

recipe/session/recipeImplementation.go

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,19 @@ var JWKRefreshRateLimit = 500
4545
// Maintains a map of the core path to the result
4646
var jwksCache *sessmodels.GetJWKSResult = nil
4747

48-
func getJWKS() []sessmodels.GetJWKSFunctionObject {
49-
result := []sessmodels.GetJWKSFunctionObject{}
48+
func getJWKS() sessmodels.GetJWKSResult {
5049
corePaths := supertokens.GetAllCoreUrlsForPath("/.well-known/jwks.json")
5150

51+
if len(corePaths) == 0 {
52+
return sessmodels.GetJWKSResult{
53+
JWKS: nil,
54+
Error: defaultErrors.New("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using."),
55+
LastFetched: 0,
56+
}
57+
}
58+
59+
var lastError error
60+
5261
for _, path := range corePaths {
5362
// Here we dont need to check if cached result had an error because we only add to cache
5463
// if the JWKS result was successful
@@ -57,86 +66,71 @@ func getJWKS() []sessmodels.GetJWKSFunctionObject {
5766
// We check if we need to refresh before returning
5867
currentTime := time.Now().UnixNano() / int64(time.Millisecond)
5968

60-
// This means that the value in cache is not expired, in this case we return a function that simply
61-
// returns the cached keys
69+
// This means that the value in cache is not expired, in this case we return the cached value
6270
//
6371
// Note that this also means that the SDK will not try to query any other Core (if there are multiple)
64-
// if it has even a single valid cache entry from one of the core URLs. It will only attempt to fetch
72+
// if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
6573
// from the cores again after the entry in the cache is expired
6674
if (currentTime - jwksCache.LastFetched) < JWKCacheMaxAgeInMs {
67-
finalResult := []sessmodels.GetJWKSFunctionObject{}
68-
69-
finalResult = append(finalResult, sessmodels.GetJWKSFunctionObject{
70-
Fn: func(_ string) sessmodels.GetJWKSResult {
71-
if supertokens.IsRunningInTestMode() {
72-
returnedFromCache = true
73-
}
74-
75-
return *jwksCache
76-
},
77-
Path: path,
78-
})
75+
if supertokens.IsRunningInTestMode() {
76+
returnedFromCache = true
77+
}
7978

80-
return finalResult
79+
return *jwksCache
8180
}
8281

8382
// This means that the value in cache is expired, we clear from cache and proceed
8483
// as if it was never cached because that would be the equivalent of refreshing
8584
//
8685
// This has the added benefit where if there are multiple cores [Core1, Core2] and initially
87-
// Core1 was down (so the cache only has a result for Core2). When Core2's cache expires the SDK
86+
// Core1 was down (so the cache only has the result from Core2). When the cache expires the SDK
8887
// will try to re-fetch for Core1 and will return that result (and save to cache) if Core1 is now up
8988
jwksCache = nil
9089
if supertokens.IsRunningInTestMode() {
9190
deleteFromCacheCount++
9291
}
9392
}
9493

95-
// We need to also save the path of the JWKS request and pass it to the function this way because
96-
// golang only saves a reference to the last value set to path when this function actually is called
97-
// so all the functions in the slice would end up calling the last core host
98-
//
99-
// Doing it this way makes sure that all individual hosts are called
100-
result = append(result, sessmodels.GetJWKSFunctionObject{
101-
Fn: func(inputPath string) sessmodels.GetJWKSResult {
102-
if supertokens.IsRunningInTestMode() {
103-
urlsAttemptedForJWKSFetch = append(urlsAttemptedForJWKSFetch, inputPath)
104-
}
94+
if supertokens.IsRunningInTestMode() {
95+
urlsAttemptedForJWKSFetch = append(urlsAttemptedForJWKSFetch, path)
96+
}
10597

106-
// RefreshUnknownKID - Fetch JWKS again if the kid in the header of the JWT does not match any in
107-
// the keyfunc library's cache
108-
jwks, err := keyfunc.Get(inputPath, keyfunc.Options{
109-
RefreshUnknownKID: true,
110-
})
98+
// RefreshUnknownKID - Fetch JWKS again if the kid in the header of the JWT does not match any in
99+
// the keyfunc library's cache
100+
jwks, jwksError := keyfunc.Get(path, keyfunc.Options{
101+
RefreshUnknownKID: true,
102+
})
111103

112-
jwksResult := sessmodels.GetJWKSResult{
113-
JWKS: jwks,
114-
Error: err,
115-
LastFetched: time.Now().UnixNano() / int64(time.Millisecond),
116-
}
104+
if jwksError == nil {
105+
jwksResult := sessmodels.GetJWKSResult{
106+
JWKS: jwks,
107+
Error: jwksError,
108+
LastFetched: time.Now().UnixNano() / int64(time.Millisecond),
109+
}
117110

118-
// Dont add to cache if there is an error to keep the logic of checking cache simple
119-
// This means that for multiple cores, the only item we add to cache would be the first
120-
// core that returned keys without an error
121-
//
122-
// This also has the added benefit where if initially the request failed because the core
123-
// was down and then it comes back up, the next time it will try to request that core again
124-
// after the cache has expired
125-
if err == nil {
126-
jwksCache = &jwksResult
127-
}
111+
// Dont add to cache if there is an error to keep the logic of checking cache simple
112+
//
113+
// This also has the added benefit where if initially the request failed because the core
114+
// was down and then it comes back up, the next time it will try to request that core again
115+
// after the cache has expired
116+
jwksCache = &jwksResult
128117

129-
if supertokens.IsRunningInTestMode() {
130-
returnedFromCache = false
131-
}
118+
if supertokens.IsRunningInTestMode() {
119+
returnedFromCache = false
120+
}
132121

133-
return jwksResult
134-
},
135-
Path: path,
136-
})
122+
return jwksResult
123+
}
124+
125+
lastError = jwksError
137126
}
138127

139-
return result
128+
// This means that fetching from all cores failed
129+
return sessmodels.GetJWKSResult{
130+
JWKS: nil,
131+
Error: lastError,
132+
LastFetched: 0,
133+
}
140134
}
141135

142136
/**
@@ -147,29 +141,17 @@ Every core instance a backend is connected to is expected to connect to the same
147141
token verification. Otherwise, the result of session verification would depend on which core is currently available.
148142
*/
149143
func GetCombinedJWKS() (*keyfunc.JWKS, error) {
150-
var lastError error
151-
jwksObjects := getJWKS()
152-
153144
if supertokens.IsRunningInTestMode() {
154145
urlsAttemptedForJWKSFetch = []string{}
155146
}
156147

157-
if len(jwksObjects) == 0 {
158-
return nil, defaultErrors.New("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using.")
159-
}
160-
161-
for _, jwkObject := range jwksObjects {
162-
jwksResult := jwkObject.Fn(jwkObject.Path)
163-
err := jwksResult.Error
148+
jwksResult := getJWKS()
164149

165-
if err != nil {
166-
lastError = err
167-
} else {
168-
return jwksResult.JWKS, nil
169-
}
150+
if jwksResult.Error != nil {
151+
return nil, jwksResult.Error
170152
}
171153

172-
return nil, lastError
154+
return jwksResult.JWKS, nil
173155
}
174156

175157
func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) sessmodels.RecipeInterface {

recipe/session/session_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,13 +1335,13 @@ func TestThatJWKSResultIsRefreshedProperly(t *testing.T) {
13351335
t.Error(err.Error())
13361336
}
13371337

1338-
jwksBefore := getJWKS()[0]
1339-
beforeKids := jwksBefore.Fn(jwksBefore.Path).JWKS.KIDs()
1338+
jwksBefore := getJWKS()
1339+
beforeKids := jwksBefore.JWKS.KIDs()
13401340

13411341
time.Sleep(3 * time.Second)
13421342

1343-
jwksAfter := getJWKS()[0]
1344-
afterKids := jwksAfter.Fn(jwksAfter.Path).JWKS.KIDs()
1343+
jwksAfter := getJWKS()
1344+
afterKids := jwksAfter.JWKS.KIDs()
13451345
var newKeys []string
13461346

13471347
for _, key := range afterKids {

0 commit comments

Comments
 (0)