diff --git a/authn/token_exchange.go b/authn/token_exchange.go index b4e13b6..7933a3e 100644 --- a/authn/token_exchange.go +++ b/authn/token_exchange.go @@ -22,6 +22,8 @@ type TokenExchanger interface { Exchange(ctx context.Context, r TokenExchangeRequest) (*TokenExchangeResponse, error) } +const defaultCacheTTL = 15 * time.Second + var _ TokenExchanger = &TokenExchangeClient{} // ExchangeClientOpts allows setting custom parameters during construction. @@ -40,6 +42,14 @@ func WithTokenExchangeClientCache(cache cache.Cache) ExchangeClientOpts { } } +// WithMinimumCacheTTL allows setting the minimum amount of time that a cache +// entry must be valid for in order for it to be reused. +func WithMinimumCacheTTL(ttl time.Duration) ExchangeClientOpts { + return func(c *TokenExchangeClient) { + c.minimumTTL = ttl + } +} + func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) (*TokenExchangeClient, error) { if cfg.Token == "" { return nil, fmt.Errorf("%w: missing required token", ErrMissingConfig) @@ -50,9 +60,10 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) } c := &TokenExchangeClient{ - cache: nil, // See below. - cfg: cfg, - singlef: singleflight.Group{}, + cache: nil, // See below. + minimumTTL: defaultCacheTTL, + cfg: cfg, + singlef: singleflight.Group{}, } for _, opt := range opts { @@ -77,14 +88,14 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) } return c, nil - } type TokenExchangeClient struct { - cache cache.Cache - cfg TokenExchangeConfig - client *http.Client - singlef singleflight.Group + cache cache.Cache + minimumTTL time.Duration // Minimum time that token must be valid to be reused. + cfg TokenExchangeConfig + client *http.Client + singlef singleflight.Group } type TokenExchangeRequest struct { @@ -207,8 +218,6 @@ func (c *TokenExchangeClient) getCache(ctx context.Context, key string) (string, } func (c *TokenExchangeClient) setCache(ctx context.Context, token string, key string) error { - const cacheLeeway = 15 * time.Second - parsed, err := jwt.ParseSigned(token) if err != nil { return fmt.Errorf("failed to parse token: %v", err) @@ -219,7 +228,7 @@ func (c *TokenExchangeClient) setCache(ctx context.Context, token string, key st return fmt.Errorf("failed to extract claims from the token: %v", err) } - return c.cache.Set(ctx, key, []byte(token), time.Until(claims.Expiry.Time())-cacheLeeway) + return c.cache.Set(ctx, key, []byte(token), time.Until(claims.Expiry.Time())-c.minimumTTL) } var _ TokenExchanger = StaticTokenExchanger{} diff --git a/authn/token_exchange_test.go b/authn/token_exchange_test.go index c91e765..9405d25 100644 --- a/authn/token_exchange_test.go +++ b/authn/token_exchange_test.go @@ -217,6 +217,28 @@ func Test_TokenExchangeClient_Exchange(t *testing.T) { }) } +func Test_WithMinimumCacheTTL(t *testing.T) { + cfg := TokenExchangeConfig{ + Token: "some-token", + TokenExchangeURL: "http://localhost", + } + + t.Run("not using WithMinimumCacheTTL should use the default", func(t *testing.T) { + client, err := NewTokenExchangeClient(cfg) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, defaultCacheTTL, client.minimumTTL) + }) + + t.Run("using WithMinimumCacheTTL should modify the value", func(t *testing.T) { + customTTL := 42 * time.Second + client, err := NewTokenExchangeClient(cfg, WithMinimumCacheTTL(customTTL)) + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, customTTL, client.minimumTTL) + }) +} + func signAccessToken(t *testing.T, expiresIn time.Duration) string { signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.HS256,