Skip to content

Commit a4a14cc

Browse files
committed
Refactor based on PR review
1 parent 7e38104 commit a4a14cc

File tree

3 files changed

+57
-57
lines changed

3 files changed

+57
-57
lines changed

recipe/session/recipeImplementation.go

Lines changed: 43 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
2727
"github.com/supertokens/supertokens-golang/supertokens"
2828
"reflect"
29+
"sync"
2930
"time"
3031
)
3132

@@ -41,56 +42,52 @@ var protectedProps = []string{
4142

4243
var JWKCacheMaxAgeInMs int64 = 60000
4344
var JWKRefreshRateLimit = 500
44-
45-
// Maintains a map of the core path to the result
4645
var jwksCache *sessmodels.GetJWKSResult = nil
46+
var mutex sync.RWMutex
47+
48+
func getJWKSFromCacheIfPresent() *sessmodels.GetJWKSResult {
49+
mutex.RLock()
50+
defer mutex.RUnlock()
51+
if jwksCache != nil {
52+
// This means that we have valid JWKs for the given core path
53+
// We check if we need to refresh before returning
54+
currentTime := time.Now().UnixNano() / int64(time.Millisecond)
55+
56+
// This means that the value in cache is not expired, in this case we return the cached value
57+
//
58+
// Note that this also means that the SDK will not try to query any other Core (if there are multiple)
59+
// if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
60+
// from the cores again after the entry in the cache is expired
61+
if (currentTime - jwksCache.LastFetched) < JWKCacheMaxAgeInMs {
62+
if supertokens.IsRunningInTestMode() {
63+
returnedFromCache = true
64+
}
65+
66+
return jwksCache
67+
}
68+
}
4769

48-
func getJWKS() sessmodels.GetJWKSResult {
70+
return nil
71+
}
72+
73+
func getJWKS() (*keyfunc.JWKS, error) {
4974
corePaths := supertokens.GetAllCoreUrlsForPath("/.well-known/jwks.json")
5075

5176
if len(corePaths) == 0 {
52-
return sessmodels.GetJWKSResult{
53-
JWKS: nil,
54-
Error: defaultErrors.New("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using."),
55-
LastFetched: 0,
56-
}
77+
return nil, defaultErrors.New("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using.")
5778
}
5879

59-
var lastError error
80+
resultFromCache := getJWKSFromCacheIfPresent()
6081

61-
for _, path := range corePaths {
62-
// Here we dont need to check if cached result had an error because we only add to cache
63-
// if the JWKS result was successful
64-
if jwksCache != nil {
65-
// This means that we have valid JWKs for the given core path
66-
// We check if we need to refresh before returning
67-
currentTime := time.Now().UnixNano() / int64(time.Millisecond)
68-
69-
// This means that the value in cache is not expired, in this case we return the cached value
70-
//
71-
// Note that this also means that the SDK will not try to query any other Core (if there are multiple)
72-
// if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
73-
// from the cores again after the entry in the cache is expired
74-
if (currentTime - jwksCache.LastFetched) < JWKCacheMaxAgeInMs {
75-
if supertokens.IsRunningInTestMode() {
76-
returnedFromCache = true
77-
}
78-
79-
return *jwksCache
80-
}
82+
if resultFromCache != nil {
83+
return resultFromCache.JWKS, nil
84+
}
8185

82-
// This means that the value in cache is expired, we clear from cache and proceed
83-
// as if it was never cached because that would be the equivalent of refreshing
84-
//
85-
// This has the added benefit where if there are multiple cores [Core1, Core2] and initially
86-
// Core1 was down (so the cache only has the result from Core2). When the cache expires the SDK
87-
// will try to re-fetch for Core1 and will return that result (and save to cache) if Core1 is now up
88-
jwksCache = nil
89-
if supertokens.IsRunningInTestMode() {
90-
deleteFromCacheCount++
91-
}
92-
}
86+
var lastError error
9387

88+
mutex.Lock()
89+
defer mutex.Unlock()
90+
for _, path := range corePaths {
9491
if supertokens.IsRunningInTestMode() {
9592
urlsAttemptedForJWKSFetch = append(urlsAttemptedForJWKSFetch, path)
9693
}
@@ -119,18 +116,14 @@ func getJWKS() sessmodels.GetJWKSResult {
119116
returnedFromCache = false
120117
}
121118

122-
return jwksResult
119+
return jwksResult.JWKS, nil
123120
}
124121

125122
lastError = jwksError
126123
}
127124

128125
// This means that fetching from all cores failed
129-
return sessmodels.GetJWKSResult{
130-
JWKS: nil,
131-
Error: lastError,
132-
LastFetched: 0,
133-
}
126+
return nil, lastError
134127
}
135128

136129
/**
@@ -145,13 +138,13 @@ func GetCombinedJWKS() (*keyfunc.JWKS, error) {
145138
urlsAttemptedForJWKSFetch = []string{}
146139
}
147140

148-
jwksResult := getJWKS()
141+
jwksResult, err := getJWKS()
149142

150-
if jwksResult.Error != nil {
151-
return nil, jwksResult.Error
143+
if err != nil {
144+
return nil, err
152145
}
153146

154-
return jwksResult.JWKS, nil
147+
return jwksResult, nil
155148
}
156149

157150
func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) sessmodels.RecipeInterface {

recipe/session/session_test.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,13 +1335,23 @@ func TestThatJWKSResultIsRefreshedProperly(t *testing.T) {
13351335
t.Error(err.Error())
13361336
}
13371337

1338-
jwksBefore := getJWKS()
1339-
beforeKids := jwksBefore.JWKS.KIDs()
1338+
jwksBefore, err := getJWKS()
1339+
1340+
if err != nil {
1341+
t.Error(err.Error())
1342+
}
1343+
1344+
beforeKids := jwksBefore.KIDs()
13401345

13411346
time.Sleep(3 * time.Second)
13421347

1343-
jwksAfter := getJWKS()
1344-
afterKids := jwksAfter.JWKS.KIDs()
1348+
jwksAfter, err := getJWKS()
1349+
1350+
if err != nil {
1351+
t.Error(err.Error())
1352+
}
1353+
1354+
afterKids := jwksAfter.KIDs()
13451355
var newKeys []string
13461356

13471357
for _, key := range afterKids {
@@ -1562,7 +1572,6 @@ func TestJWKSCacheLogic(t *testing.T) {
15621572
t.Error(err.Error())
15631573
}
15641574

1565-
assert.Equal(t, deleteFromCacheCount, 1)
15661575
assert.NotNil(t, jwksCache)
15671576

15681577
JWKRefreshRateLimit = originalRefreshlimit

recipe/session/testingUtils.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,13 @@ import (
2424

2525
// Testing constants
2626
var didGetSessionCallCore = false
27-
var deleteFromCacheCount = 0
2827
var returnedFromCache = false
2928
var urlsAttemptedForJWKSFetch []string
3029

3130
func resetAll() {
3231
supertokens.ResetForTest()
3332
ResetForTest()
3433
didGetSessionCallCore = false
35-
deleteFromCacheCount = 0
3634
returnedFromCache = false
3735
urlsAttemptedForJWKSFetch = []string{}
3836
jwksCache = nil

0 commit comments

Comments
 (0)