@@ -26,6 +26,7 @@ import (
2626 "github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
2727 "github.com/supertokens/supertokens-golang/supertokens"
2828 "reflect"
29+ "sync"
2930 "time"
3031)
3132
@@ -41,56 +42,52 @@ var protectedProps = []string{
4142
4243var JWKCacheMaxAgeInMs int64 = 60000
4344var JWKRefreshRateLimit = 500
44-
45- // Maintains a map of the core path to the result
4645var jwksCache * sessmodels.GetJWKSResult = nil
46+ var mutex sync.RWMutex
47+
48+ func getJWKSFromCacheIfPresent () * sessmodels.GetJWKSResult {
49+ mutex .RLock ()
50+ defer mutex .RUnlock ()
51+ if jwksCache != nil {
52+ // This means that we have valid JWKs for the given core path
53+ // We check if we need to refresh before returning
54+ currentTime := time .Now ().UnixNano () / int64 (time .Millisecond )
55+
56+ // This means that the value in cache is not expired, in this case we return the cached value
57+ //
58+ // Note that this also means that the SDK will not try to query any other Core (if there are multiple)
59+ // if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
60+ // from the cores again after the entry in the cache is expired
61+ if (currentTime - jwksCache .LastFetched ) < JWKCacheMaxAgeInMs {
62+ if supertokens .IsRunningInTestMode () {
63+ returnedFromCache = true
64+ }
65+
66+ return jwksCache
67+ }
68+ }
4769
48- func getJWKS () sessmodels.GetJWKSResult {
70+ return nil
71+ }
72+
73+ func getJWKS () (* keyfunc.JWKS , error ) {
4974 corePaths := supertokens .GetAllCoreUrlsForPath ("/.well-known/jwks.json" )
5075
5176 if len (corePaths ) == 0 {
52- return sessmodels.GetJWKSResult {
53- JWKS : nil ,
54- Error : defaultErrors .New ("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using." ),
55- LastFetched : 0 ,
56- }
77+ return nil , defaultErrors .New ("No SuperTokens core available to query. Please pass supertokens > connectionURI to the init function, or override all the functions of the recipe you are using." )
5778 }
5879
59- var lastError error
80+ resultFromCache := getJWKSFromCacheIfPresent ()
6081
61- for _ , path := range corePaths {
62- // Here we dont need to check if cached result had an error because we only add to cache
63- // if the JWKS result was successful
64- if jwksCache != nil {
65- // This means that we have valid JWKs for the given core path
66- // We check if we need to refresh before returning
67- currentTime := time .Now ().UnixNano () / int64 (time .Millisecond )
68-
69- // This means that the value in cache is not expired, in this case we return the cached value
70- //
71- // Note that this also means that the SDK will not try to query any other Core (if there are multiple)
72- // if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
73- // from the cores again after the entry in the cache is expired
74- if (currentTime - jwksCache .LastFetched ) < JWKCacheMaxAgeInMs {
75- if supertokens .IsRunningInTestMode () {
76- returnedFromCache = true
77- }
78-
79- return * jwksCache
80- }
82+ if resultFromCache != nil {
83+ return resultFromCache .JWKS , nil
84+ }
8185
82- // This means that the value in cache is expired, we clear from cache and proceed
83- // as if it was never cached because that would be the equivalent of refreshing
84- //
85- // This has the added benefit where if there are multiple cores [Core1, Core2] and initially
86- // Core1 was down (so the cache only has the result from Core2). When the cache expires the SDK
87- // will try to re-fetch for Core1 and will return that result (and save to cache) if Core1 is now up
88- jwksCache = nil
89- if supertokens .IsRunningInTestMode () {
90- deleteFromCacheCount ++
91- }
92- }
86+ var lastError error
9387
88+ mutex .Lock ()
89+ defer mutex .Unlock ()
90+ for _ , path := range corePaths {
9491 if supertokens .IsRunningInTestMode () {
9592 urlsAttemptedForJWKSFetch = append (urlsAttemptedForJWKSFetch , path )
9693 }
@@ -119,18 +116,14 @@ func getJWKS() sessmodels.GetJWKSResult {
119116 returnedFromCache = false
120117 }
121118
122- return jwksResult
119+ return jwksResult . JWKS , nil
123120 }
124121
125122 lastError = jwksError
126123 }
127124
128125 // This means that fetching from all cores failed
129- return sessmodels.GetJWKSResult {
130- JWKS : nil ,
131- Error : lastError ,
132- LastFetched : 0 ,
133- }
126+ return nil , lastError
134127}
135128
136129/**
@@ -145,13 +138,13 @@ func GetCombinedJWKS() (*keyfunc.JWKS, error) {
145138 urlsAttemptedForJWKSFetch = []string {}
146139 }
147140
148- jwksResult := getJWKS ()
141+ jwksResult , err := getJWKS ()
149142
150- if jwksResult . Error != nil {
151- return nil , jwksResult . Error
143+ if err != nil {
144+ return nil , err
152145 }
153146
154- return jwksResult . JWKS , nil
147+ return jwksResult , nil
155148}
156149
157150func MakeRecipeImplementation (querier supertokens.Querier , config sessmodels.TypeNormalisedInput , appInfo supertokens.NormalisedAppinfo ) sessmodels.RecipeInterface {
0 commit comments