Skip to content

Commit 5455052

Browse files
authored
Add WithCache option to for NewKeyRetriever (#166)
It might be worth offering the option to set a custom cache, possibly with different policies or with a different caching store. This is almost supported, as cache.Cache is an interface, so it's possible to reimplement it. This is done using a DefaultKeyRetrieverOption function to pass to NewKeyRetriever that can replace the cache field `c`. The problem is that the field is not exported, so it's not possible to implement the option outside of the package. Add that option here. It needs to take into consideration the fact that cache.NewLocalCache ends up starting a goroutine to do clean up tasks. While the goroutine should stop because the cache has been replaced, it's better to not start it in the first place. Signed-off-by: Marcelo E. Magallon <marcelo.magallon@grafana.com>
1 parent 2f4a826 commit 5455052

File tree

4 files changed

+187
-10
lines changed

4 files changed

+187
-10
lines changed

authn/jwks.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,43 @@ func WithHTTPClientKeyRetrieverOpt(client *http.Client) DefaultKeyRetrieverOptio
2727
}
2828
}
2929

30+
func WithKeyRetrieverCache(cache cache.Cache) DefaultKeyRetrieverOption {
31+
return func(c *DefaultKeyRetriever) {
32+
c.c = cache
33+
}
34+
}
35+
3036
const (
3137
cacheTTL = 10 * time.Minute
3238
cacheCleanupInterval = 10 * time.Minute
3339
)
3440

3541
func NewKeyRetriever(cfg KeyRetrieverConfig, opt ...DefaultKeyRetrieverOption) *DefaultKeyRetriever {
3642
s := &DefaultKeyRetriever{
37-
cfg: cfg,
38-
c: cache.NewLocalCache(cache.Config{
39-
Expiry: cacheTTL,
40-
CleanupInterval: cacheCleanupInterval,
41-
}),
43+
cfg: cfg,
44+
c: nil, // See below.
4245
client: http.DefaultClient,
4346
s: &singleflight.Group{},
4447
}
4548

4649
for _, o := range opt {
4750
o(s)
4851
}
52+
53+
// If the options did not set the cache, create a new local cache.
54+
//
55+
// This has to be done this way because the cache that is created by
56+
// the cache.NewLocalCache function spawns a goroutine that cannot be
57+
// trivially stopped. It is set up to stop when the object is garbage
58+
// collected, but in the general case, the calling code will not have
59+
// control over that.
60+
if s.c == nil {
61+
s.c = cache.NewLocalCache(cache.Config{
62+
Expiry: cacheTTL,
63+
CleanupInterval: cacheCleanupInterval,
64+
})
65+
}
66+
4967
return s
5068
}
5169

authn/jwks_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@ package authn
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"net/http"
89
"net/http/httptest"
10+
"sync"
911
"testing"
12+
"time"
1013

1114
"github.com/go-jose/go-jose/v3"
15+
"github.com/grafana/authlib/cache"
1216
"github.com/stretchr/testify/assert"
1317
"github.com/stretchr/testify/require"
1418
)
@@ -36,6 +40,9 @@ func TestDefaultKeyRetriever_Get(t *testing.T) {
3640
SigningKeysURL: server.URL,
3741
})
3842

43+
require.NotNil(t, service)
44+
require.NotNil(t, service.c)
45+
3946
t.Run("should fetched key if not cached", func(t *testing.T) {
4047
key, err := service.Get(context.Background(), firstKeyID)
4148
require.NoError(t, err)
@@ -66,3 +73,106 @@ func TestDefaultKeyRetriever_Get(t *testing.T) {
6673
}
6774
})
6875
}
76+
77+
func TestWithKeyRetrieverCache(t *testing.T) {
78+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
79+
w.WriteHeader(http.StatusOK)
80+
_, _ = w.Write(keys())
81+
}))
82+
83+
t.Cleanup(func() {
84+
server.CloseClientConnections()
85+
server.Close()
86+
})
87+
88+
tc := &testCache{data: make(map[string][]byte)}
89+
90+
// Create a new retriever with the test cache.
91+
service := NewKeyRetriever(KeyRetrieverConfig{
92+
SigningKeysURL: server.URL,
93+
}, WithKeyRetrieverCache(tc))
94+
95+
require.NotNil(t, service, "service should not be nil")
96+
require.NotNil(t, service.c, "there should be a cache")
97+
require.Equal(t, tc, service.c, "the cache should be the one passed in the options")
98+
99+
// Validate that the key is not in the cache
100+
data, err := tc.Get(context.Background(), firstKeyID)
101+
require.Error(t, err, "the initial cache should be empty")
102+
require.Nil(t, data, "the initial cache should be empty")
103+
104+
// The cache is empty, so the implementation should fetch the key.
105+
key, err := service.Get(context.Background(), firstKeyID)
106+
require.NoError(t, err, "getting a key not present in the cache should not return an error")
107+
require.NotNil(t, key, "Get should return a key")
108+
assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested")
109+
110+
// If the implementation called the cache, the data should be there now.
111+
data, err = tc.Get(context.Background(), firstKeyID)
112+
require.NoError(t, err, "the cache should have the key now")
113+
require.NotNil(t, data, "the cache should have the key now")
114+
115+
// Decode the data to validate that it matches the key. We know the
116+
// entries in the cache are JSON-encoded keys.
117+
var jwk jose.JSONWebKey
118+
require.NoError(t, json.Unmarshal(data, &jwk), "the data should be valid JSON")
119+
require.Equal(t, firstKeyID, jwk.KeyID, "the key id should match the one requested")
120+
121+
// Remove the key from the cache; the implementation should still return the key.
122+
err = tc.Delete(context.Background(), firstKeyID)
123+
require.NoError(t, err, "deleting the key from the cache should not return an error")
124+
125+
key, err = service.Get(context.Background(), firstKeyID)
126+
require.NoError(t, err, "getting a key not present in the cache should not return an error")
127+
require.NotNil(t, key, "Get should return a key")
128+
assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested")
129+
130+
// Retrieve an invalid key; the implementation should return an error.
131+
key, err = service.Get(context.Background(), "invalid")
132+
require.ErrorIs(t, err, ErrInvalidSigningKey)
133+
require.Nil(t, key)
134+
135+
// The implementation adds invalid keys to the cache to prevent re-fetching.
136+
data, err = tc.Get(context.Background(), "invalid")
137+
require.NoError(t, err, "the cache should have the invalid key now")
138+
require.NotNil(t, data, "the cache should have the invalid key now")
139+
require.Empty(t, data, "the cache should have the invalid key now")
140+
}
141+
142+
// testCache implements the Cache interface for testing purposes.
143+
type testCache struct {
144+
mu sync.Mutex
145+
data map[string][]byte
146+
}
147+
148+
var _ cache.Cache = (*testCache)(nil)
149+
150+
func (cache *testCache) Get(ctx context.Context, key string) ([]byte, error) {
151+
cache.mu.Lock()
152+
defer cache.mu.Unlock()
153+
154+
item, ok := cache.data[key]
155+
if !ok {
156+
return nil, errors.New("not found")
157+
}
158+
159+
return item, nil
160+
}
161+
162+
func (cache *testCache) Set(ctx context.Context, key string, value []byte, expire time.Duration) error {
163+
cache.mu.Lock()
164+
defer cache.mu.Unlock()
165+
166+
cache.data[key] = value
167+
168+
return nil
169+
}
170+
171+
func (cache *testCache) Delete(ctx context.Context, key string) error {
172+
cache.mu.Lock()
173+
defer cache.mu.Unlock()
174+
175+
delete(cache.data, key)
176+
177+
return nil
178+
}

authn/token_exchange.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ func WithHTTPClient(client *http.Client) ExchangeClientOpts {
3434
}
3535
}
3636

37+
func WithTokenExchangeClientCache(cache cache.Cache) ExchangeClientOpts {
38+
return func(c *TokenExchangeClient) {
39+
c.cache = cache
40+
}
41+
}
42+
3743
func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) (*TokenExchangeClient, error) {
3844
if cfg.Token == "" {
3945
return nil, fmt.Errorf("%w: missing required token", ErrMissingConfig)
@@ -44,9 +50,7 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts)
4450
}
4551

4652
c := &TokenExchangeClient{
47-
cache: cache.NewLocalCache(cache.Config{
48-
CleanupInterval: 5 * time.Minute,
49-
}),
53+
cache: nil, // See below.
5054
cfg: cfg,
5155
singlef: singleflight.Group{},
5256
}
@@ -59,6 +63,19 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts)
5963
c.client = httpclient.New()
6064
}
6165

66+
// If the options did not set the cache, create a new local cache.
67+
//
68+
// This has to be done this way because the cache that is created by
69+
// the cache.NewLocalCache function spawns a goroutine that cannot be
70+
// trivially stopped. It is set up to stop when the object is garbage
71+
// collected, but in the general case, the calling code will not have
72+
// control over that.
73+
if c.cache == nil {
74+
c.cache = cache.NewLocalCache(cache.Config{
75+
CleanupInterval: 5 * time.Minute,
76+
})
77+
}
78+
6279
return c, nil
6380

6481
}

authn/token_exchange_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ func TestNewTokenExchangeClient(t *testing.T) {
4343

4444
func Test_TokenExchangeClient_Exchange(t *testing.T) {
4545
expiresIn := 10 * time.Minute
46-
setup := func(srv *httptest.Server) *TokenExchangeClient {
46+
setup := func(srv *httptest.Server, opts ...ExchangeClientOpts) *TokenExchangeClient {
4747
c, err := NewTokenExchangeClient(TokenExchangeConfig{
4848
Token: "some-token",
4949
TokenExchangeURL: srv.URL,
50-
})
50+
}, opts...)
5151
require.NoError(t, err)
5252
return c
5353
}
@@ -183,6 +183,38 @@ func Test_TokenExchangeClient_Exchange(t *testing.T) {
183183
expectedExpiry := time.Now().Add(time.Duration(expiresIn) * time.Second)
184184
require.InDelta(t, expectedExpiry.Unix(), claims.Expiry.Time().Unix(), 1)
185185
})
186+
187+
t.Run("should use an alternate cache if provided", func(t *testing.T) {
188+
testcache := &testCache{data: make(map[string][]byte)}
189+
190+
var calls int
191+
c := setup(httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
192+
calls++
193+
require.Equal(t, r.Header.Get("Authorization"), "Bearer some-token")
194+
w.WriteHeader(http.StatusOK)
195+
_, _ = w.Write([]byte(`{"data": {"token": "` + signAccessToken(t, expiresIn) + `"}}`))
196+
bytes.NewBuffer([]byte(`{}`))
197+
json.NewEncoder(&bytes.Buffer{})
198+
})), WithTokenExchangeClientCache(testcache))
199+
200+
tokenToBeExchanged := signAccessToken(t, expiresIn)
201+
202+
res1, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged})
203+
assert.NoError(t, err)
204+
assert.NotNil(t, res1)
205+
require.Equal(t, 1, calls)
206+
require.Len(t, testcache.data, 1)
207+
208+
// same namespace and audiences should load token from cache
209+
res2, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged})
210+
assert.NoError(t, err)
211+
assert.NotNil(t, res2)
212+
require.Equal(t, 1, calls)
213+
require.Len(t, testcache.data, 1)
214+
require.Equal(t, res1, res2)
215+
216+
// This is only testing that the cache is used, so we do not repeat the other cases here.
217+
})
186218
}
187219

188220
func signAccessToken(t *testing.T, expiresIn time.Duration) string {

0 commit comments

Comments
 (0)