Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions client_manager.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package apns2

import (
"sync"
"time"

"container/list"
"crypto/sha1"
"crypto/tls"
"sync"
"time"

"github.com/sideshow/apns2/token"
)

type managerItem struct {
Expand All @@ -31,6 +34,10 @@ type ClientManager struct {
// manager.
Factory func(certificate tls.Certificate) *Client

// FactoryToken is the function which constructs clients if not found in the
// manager when token auth is used
FactoryToken func(token *token.Token) *Client

cache map[[sha1.Size]byte]*list.Element
ll *list.List
mu sync.Mutex
Expand All @@ -48,9 +55,10 @@ type ClientManager struct {
// a Client with default options.
func NewClientManager() *ClientManager {
manager := &ClientManager{
MaxSize: 64,
MaxAge: 10 * time.Minute,
Factory: NewClient,
MaxSize: 64,
MaxAge: 10 * time.Minute,
Factory: NewClient,
FactoryToken: NewTokenClient,
}

manager.initInternals()
Expand All @@ -65,7 +73,14 @@ func (m *ClientManager) Add(client *Client) {
m.mu.Lock()
defer m.mu.Unlock()

key := cacheKey(client.Certificate)
var key [sha1.Size]byte

if client.Token != nil {
key = cacheTokenKey(client.Token)
} else {
key = cacheKey(client.Certificate)
}

now := time.Now()
if ele, hit := m.cache[key]; hit {
item := ele.Value.(*managerItem)
Expand All @@ -88,16 +103,35 @@ func (m *ClientManager) Add(client *Client) {
// the ClientManager's Factory function, store the result in the manager if
// non-nil, and return it.
func (m *ClientManager) Get(certificate tls.Certificate) *Client {
key := cacheKey(certificate)

return m.get(key, func() *Client {
return m.Factory(certificate)
})
}

// Get gets a Client from the manager. If a Client is not found in the manager
// or if a Client has remained in the manager longer than MaxAge, Get will call
// the ClientManager's Factory function, store the result in the manager if
// non-nil, and return it.
func (m *ClientManager) GetByToken(token *token.Token) *Client {
key := cacheTokenKey(token)

return m.get(key, func() *Client {
return m.FactoryToken(token)
})
}

func (m *ClientManager) get(key [sha1.Size]byte, factory func() *Client) *Client {
m.initInternals()
m.mu.Lock()
defer m.mu.Unlock()

key := cacheKey(certificate)
now := time.Now()
if ele, hit := m.cache[key]; hit {
item := ele.Value.(*managerItem)
if m.MaxAge != 0 && item.lastUsed.Before(now.Add(-m.MaxAge)) {
c := m.Factory(certificate)
c := factory()
if c == nil {
return nil
}
Expand All @@ -108,7 +142,7 @@ func (m *ClientManager) Get(certificate tls.Certificate) *Client {
return item.client
}

c := m.Factory(certificate)
c := factory()
if c == nil {
return nil
}
Expand Down Expand Up @@ -160,3 +194,7 @@ func cacheKey(certificate tls.Certificate) [sha1.Size]byte {

return sha1.Sum(data)
}

func cacheTokenKey(token *token.Token) [sha1.Size]byte {
return sha1.Sum([]byte(token.Bearer))
}