diff --git a/client_manager.go b/client_manager.go index bb4bdf90..8ff70808 100644 --- a/client_manager.go +++ b/client_manager.go @@ -6,6 +6,8 @@ import ( "crypto/tls" "sync" "time" + + "github.com/sideshow/apns2/token" ) type managerItem struct { @@ -31,6 +33,10 @@ type ClientManager struct { // manager. Factory func(certificate tls.Certificate) *Client + // FactoryToken is the function which constructs clients if not found in the + // manager when token auth is used + FactoryToken func(token *token.Token) *Client + cache map[[sha1.Size]byte]*list.Element ll *list.List mu sync.Mutex @@ -48,9 +54,10 @@ type ClientManager struct { // a Client with default options. func NewClientManager() *ClientManager { manager := &ClientManager{ - MaxSize: 64, - MaxAge: 10 * time.Minute, - Factory: NewClient, + MaxSize: 64, + MaxAge: 10 * time.Minute, + Factory: NewClient, + FactoryToken: NewTokenClient, } manager.initInternals() @@ -65,7 +72,14 @@ func (m *ClientManager) Add(client *Client) { m.mu.Lock() defer m.mu.Unlock() - key := cacheKey(client.Certificate) + var key [sha1.Size]byte + + if client.Token != nil { + key = cacheTokenKey(client.Token) + } else { + key = cacheKey(client.Certificate) + } + now := time.Now() if ele, hit := m.cache[key]; hit { item := ele.Value.(*managerItem) @@ -88,16 +102,35 @@ func (m *ClientManager) Add(client *Client) { // the ClientManager's Factory function, store the result in the manager if // non-nil, and return it. func (m *ClientManager) Get(certificate tls.Certificate) *Client { + key := cacheKey(certificate) + + return m.get(key, func() *Client { + return m.Factory(certificate) + }) +} + +// Get gets a Client from the manager. If a Client is not found in the manager +// or if a Client has remained in the manager longer than MaxAge, Get will call +// the ClientManager's Factory function, store the result in the manager if +// non-nil, and return it. +func (m *ClientManager) GetByToken(token *token.Token) *Client { + key := cacheTokenKey(token) + + return m.get(key, func() *Client { + return m.FactoryToken(token) + }) +} + +func (m *ClientManager) get(key [sha1.Size]byte, factory func() *Client) *Client { m.initInternals() m.mu.Lock() defer m.mu.Unlock() - key := cacheKey(certificate) now := time.Now() if ele, hit := m.cache[key]; hit { item := ele.Value.(*managerItem) if m.MaxAge != 0 && item.lastUsed.Before(now.Add(-m.MaxAge)) { - c := m.Factory(certificate) + c := factory() if c == nil { return nil } @@ -108,7 +141,7 @@ func (m *ClientManager) Get(certificate tls.Certificate) *Client { return item.client } - c := m.Factory(certificate) + c := factory() if c == nil { return nil } @@ -160,3 +193,7 @@ func cacheKey(certificate tls.Certificate) [sha1.Size]byte { return sha1.Sum(data) } + +func cacheTokenKey(token *token.Token) [sha1.Size]byte { + return sha1.Sum([]byte(token.Bearer)) +} diff --git a/client_manager_test.go b/client_manager_test.go index c4927088..ab07e6d6 100644 --- a/client_manager_test.go +++ b/client_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/sideshow/apns2" "github.com/sideshow/apns2/certificate" + "github.com/sideshow/apns2/token" "github.com/stretchr/testify/assert" ) @@ -147,3 +148,67 @@ func TestClientManagerAddTwice(t *testing.T) { manager.Add(apns2.NewClient(mockCert())) assert.Equal(t, 1, manager.Len()) } + +func TestClientManagerAddTokenClientWithoutNew(t *testing.T) { + fn := func(token *token.Token) *apns2.Client { + t.Fatal("factory should not have been called") + return nil + } + + manager := apns2.NewClientManager() + manager.FactoryToken = fn + token := mockToken() + manager.Add(apns2.NewTokenClient(token)) + manager.GetByToken(token) +} + +func TestClientManagerAddTokenClientWithNew(t *testing.T) { + manager := apns2.NewClientManager() + + t1 := mockToken() + _, err := t1.Generate() + assert.NoError(t, err) + + t2 := mockToken() + _, err = t2.Generate() + assert.NoError(t, err) + + manager.Add(apns2.NewTokenClient(t1)) + manager.Add(apns2.NewTokenClient(t2)) + assert.Equal(t, 2, manager.Len()) +} + +func TestClientManagerGetByTokenWithoutNew(t *testing.T) { + manager := apns2.NewClientManager() + + token := mockToken() + c1 := manager.GetByToken(token) + c2 := manager.GetByToken(token) + v1 := reflect.ValueOf(c1) + v2 := reflect.ValueOf(c2) + assert.NotNil(t, c1) + assert.Equal(t, v1.Pointer(), v2.Pointer()) + assert.Equal(t, 1, manager.Len()) +} + +func TestClientManagerGetByTokenWithNew(t *testing.T) { + manager := apns2.NewClientManager() + + t1 := mockToken() + _, err := t1.Generate() + assert.NoError(t, err) + + t2 := mockToken() + _, err = t2.Generate() + assert.NoError(t, err) + + c1 := manager.GetByToken(t1) + c2 := manager.GetByToken(t2) + + v1 := reflect.ValueOf(c1) + v2 := reflect.ValueOf(c2) + assert.NotNil(t, c1) + assert.NotNil(t, c2) + assert.NotEqual(t, v1.Pointer(), v2.Pointer()) + assert.Equal(t, 2, manager.Len()) +}