diff --git a/.travis.yml b/.travis.yml index 18c443a..0987969 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,14 +1,21 @@ +sudo: false language: go go: - - 1.3 -services: - - redis-server + - 1.4.2 before_script: - go get github.com/onsi/ginkgo - go get github.com/onsi/gomega - - go get code.google.com/p/go.tools/cmd/cover + - go get golang.org/x/tools/cmd/cover - go install github.com/onsi/ginkgo/ginkgo -script: ginkgo -r --skipMeasurements --cover --trace +script: ginkgo -r --skipMeasurements --cover --trace --race env: global: - PATH=$HOME/gopath/bin:$PATH + +notifications: + email: false + +branches: + only: + - master + - develop diff --git a/apns_suite_test.go b/apns_suite_test.go index b0bcca4..7d40d9f 100644 --- a/apns_suite_test.go +++ b/apns_suite_test.go @@ -5,8 +5,45 @@ import ( . "github.com/onsi/gomega" "testing" + "time" ) +type mockConn struct { + connect func() error + read func([]byte) (int, error) + setReadDeadline func(time.Time) error +} + +func (m *mockConn) Connect() error { + if m.connect != nil { + return m.connect() + } + + return nil +} + +func (m *mockConn) Read(b []byte) (int, error) { + if m.read != nil { + return m.read(b) + } + return 0, nil +} + +func (m *mockConn) Write([]byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Close() error { + return nil +} + +func (m *mockConn) SetReadDeadline(t time.Time) error { + if m.setReadDeadline != nil { + return m.setReadDeadline(t) + } + return nil +} + func TestApns(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Apns Suite") diff --git a/buffer.go b/buffer.go new file mode 100644 index 0000000..23ba78f --- /dev/null +++ b/buffer.go @@ -0,0 +1,79 @@ +package apns + +import ( + "container/list" + "sync" +) + +// circular buffer of sent messages +// to retry if connection is dropped +type buffer struct { + size int + m sync.Mutex + *list.List +} + +func newBuffer(size int) *buffer { + return &buffer{size, sync.Mutex{}, list.New()} +} + +func (b *buffer) Add(v interface{}) *list.Element { + b.m.Lock() + defer b.m.Unlock() + + e := b.PushBack(v) + + if b.Len() > b.size { + b.Remove(b.Front()) + } + + return e +} + +// NotificationResult associates an error from Apple to a notification. +type NotificationResult struct { + Notif Notification + Err Error +} + +func (s NotificationResult) Error() string { + return s.Err.Error() +} + +func (b *buffer) FindFailedNotification(e Error) NotificationResult { + for cursor := b.Back(); cursor != nil; cursor = cursor.Prev() { + // Get serialized notification + n, _ := cursor.Value.(Notification) + + // If the notification, move cursor after the trouble notification + if n.Identifier == e.Identifier { + return NotificationResult{n, e} + } + } + return NotificationResult{Notification{}, e} +} + +// RequeueableNotifications returns good notifications sent after an error. +func (b *buffer) RequeueableNotifications(identifier uint32) []Notification { + notifs := []Notification{} + + // Walk back to last known good notification and return the slice + var e *list.Element + for e = b.Front(); e != nil; e = e.Next() { + if n, ok := e.Value.(Notification); ok && n.Identifier == identifier { + break + } + } + + // Start right after errored ID and get the rest of the list + for e = e.Next(); e != nil; e = e.Next() { + n, ok := e.Value.(Notification) + if !ok { + continue + } + + notifs = append(notifs, n) + } + + return notifs +} diff --git a/buffer_test.go b/buffer_test.go new file mode 100644 index 0000000..d6260c4 --- /dev/null +++ b/buffer_test.go @@ -0,0 +1,11 @@ +package apns + +import ( + . "github.com/onsi/ginkgo" + // . "github.com/onsi/gomega" +) + +var _ = Describe("Session", func() { + Describe("RequeueableNotifications", func() { + }) +}) diff --git a/client.go b/client.go index de7ab1a..52c34d3 100644 --- a/client.go +++ b/client.go @@ -1,219 +1,101 @@ package apns import ( - "container/list" "crypto/tls" - "io" - "log" + "sync" "time" ) -type buffer struct { - size int - *list.List -} - -func newBuffer(size int) *buffer { - return &buffer{size, list.New()} -} - -func (b *buffer) Add(v interface{}) *list.Element { - e := b.PushBack(v) - - if b.Len() > b.size { - b.Remove(b.Front()) - } - - return e -} - +// Client creates a session with Apple and handles reconnection. type Client struct { - Conn *Conn - FailedNotifs chan NotificationResult + conn Conn - notifs chan Notification - id uint32 + sess Session + sessm sync.Mutex } -func newClientWithConn(gw string, conn Conn) Client { - c := Client{ - Conn: &conn, - FailedNotifs: make(chan NotificationResult), - id: uint32(1), - notifs: make(chan Notification), - } +func newClientWithConn(conn Conn) (Client, error) { + c := Client{conn: conn} - go c.runLoop() + sess := newSession(conn) + err := sess.Connect() + if err != nil { + return c, err + } - return c + return Client{conn, sess, sync.Mutex{}}, nil } -func NewClientWithCert(gw string, cert tls.Certificate) Client { +// NewClientWithCert creates a client of the Apple gateway given a certificate. +func NewClientWithCert(gw string, cert tls.Certificate) (Client, error) { conn := NewConnWithCert(gw, cert) - - return newClientWithConn(gw, conn) + return newClientWithConn(conn) } +// NewClient is a helper that accepts a certificate/key pair. func NewClient(gw string, cert string, key string) (Client, error) { conn, err := NewConn(gw, cert, key) if err != nil { return Client{}, err } - return newClientWithConn(gw, conn), nil + return newClientWithConn(conn) } +// NewClientWithFiles is a helper that loads a certificate/key from files. func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, error) { conn, err := NewConnWithFiles(gw, certFile, keyFile) if err != nil { return Client{}, err } - return newClientWithConn(gw, conn), nil + return newClientWithConn(conn) } +// Send a notification, handling disconnections. func (c *Client) Send(n Notification) error { - c.notifs <- n - return nil -} - -func (c *Client) reportFailedPush(v interface{}, err *Error) { - failedNotif, ok := v.(Notification) - if !ok || v == nil { - return + if c.sess.Disconnected() { + c.reconnectAndRequeue() } - select { - case c.FailedNotifs <- NotificationResult{Notif: failedNotif, Err: *err}: - default: - } + return c.sess.Send(n) } -func (c *Client) requeue(cursor *list.Element) { - // If `cursor` is not nil, this means there are notifications that - // need to be delivered (or redelivered) - for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(Notification); ok { - go func() { c.notifs <- n }() - } - } -} - -func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { - cursor := buffer.Back() +func (c *Client) reconnectAndRequeue() { + c.sessm.Lock() + defer c.sessm.Unlock() - for cursor != nil { - // Get notification - n, _ := cursor.Value.(Notification) + // Pull off undelivered notifications + notifs := c.sess.RequeueableNotifications() - // If the notification, move cursor after the trouble notification - if n.Identifier == err.Identifier { - go c.reportFailedPush(cursor.Value, err) + c.reconnect() - next := cursor.Next() - - buffer.Remove(cursor) - return next - } - - cursor = cursor.Prev() + for _, n := range notifs { + // TODO handle error from sending + c.sess.Send(n) } - - return cursor } -func (c *Client) runLoop() { - sent := newBuffer(50) - cursor := sent.Front() +func (c *Client) reconnect() { + c.sess = nil - // APNS connection - for { - err := c.Conn.Connect() + for c.sess == nil { + sess := newSession(c.conn) + + err := sess.Connect() if err != nil { - // TODO Probably want to exponentially backoff... + // TODO retry policy + // TODO connect error channel + // Keep trying to connect time.Sleep(1 * time.Second) continue } - // Start reading errors from APNS - errs := readErrs(c.Conn) - - c.requeue(cursor) - - // Connection open, listen for notifs and errors - for { - var err error - var n Notification - - // Check for notifications or errors. There is a chance we'll send notifications - // if we already have an error since `select` will "pseudorandomly" choose a - // ready channels. It turns out to be fine because the connection will already - // be closed and it'll requeue. We could check before we get to this select - // block, but it doesn't seem worth the extra code and complexity. - select { - case err = <-errs: - case n = <-c.notifs: - } - - // If there is an error we understand, find the notification that failed, - // move the cursor right after it. - if nErr, ok := err.(*Error); ok { - cursor = c.handleError(nErr, sent) - break - } - - if err != nil { - break - } - - // Add to list - cursor = sent.Add(n) - - // Set identifier if not specified - if n.Identifier == 0 { - n.Identifier = c.id - c.id++ - } else if c.id < n.Identifier { - c.id = n.Identifier + 1 - } - - b, err := n.ToBinary() - if err != nil { - // TODO - continue - } - - _, err = c.Conn.Write(b) - - if err == io.EOF { - log.Println("EOF trying to write notification") - break - } - - if err != nil { - log.Println("err writing to apns", err.Error()) - break - } - - cursor = cursor.Next() - } + c.sess = sess } } -func readErrs(c *Conn) chan error { - errs := make(chan error) - - go func() { - p := make([]byte, 6, 6) - _, err := c.Read(p) - if err != nil { - errs <- err - return - } - - e := NewError(p) - errs <- &e - }() - - return errs +// newSession for altering in tests +var newSession = func(c Conn) Session { + return NewSession(c) } diff --git a/client_test.go b/client_test.go index c9dfd47..5363eb4 100644 --- a/client_test.go +++ b/client_test.go @@ -1,38 +1,91 @@ -package apns_test +package apns import ( - "bytes" - "encoding/binary" + "errors" "io/ioutil" "os" - "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "github.com/timehop/apns" + "github.com/timehop/tcptest" ) +type mockSession struct { + sendCB func(n Notification) error + requeueNotifs []Notification + disconnectedState bool +} + +func (m *mockSession) Send(n Notification) error { + if m.sendCB == nil { + return nil + } + + return m.sendCB(n) +} + +func (m *mockSession) Connect() error { + return nil +} + +func (m *mockSession) RequeueableNotifications() []Notification { + if len(m.requeueNotifs) == 0 { + return []Notification{} + } + + return m.requeueNotifs +} + +func (m *mockSession) Disconnect() { + m.disconnectedState = true +} + +func (m *mockSession) Disconnected() bool { + return m.disconnectedState +} + +type badConnMockSession struct { + *mockSession +} + +func (m badConnMockSession) Connect() error { + return errors.New("whatev") +} + var _ = Describe("Client", func() { - Describe(".NewConn", func() { + BeforeEach(func() { + newSession = func(_ Conn) Session { return &mockSession{} } + }) + + Describe(".NewClient", func() { Context("bad cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClient(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) Context("valid cert/key pair", func() { It("should create a valid client", func() { - c, err := apns.NewClient(apns.ProductionGateway, DummyCert, DummyKey) + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) + }) + }) + + Context("bad connection", func() { + It("should error out", func() { + newSession = func(_ Conn) Session { return badConnMockSession{} } + + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).NotTo(BeNil()) }) }) }) - Describe(".NewConnWithFiles", func() { + Describe(".NewClientWithFiles", func() { Context("missing cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClientWithFiles(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) @@ -42,11 +95,11 @@ var _ = Describe("Client", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -61,324 +114,92 @@ var _ = Describe("Client", func() { }) It("should create a valid client", func() { - c, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) + _, err := NewClientWithFiles(ProductionGateway, certFile.Name(), keyFile.Name()) Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) }) }) }) - Describe("#Send", func() { - Context("simple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + Describe("Send", func() { + Context("connected", func() { + Context("valid push", func() { + It("should not return an error", func() { + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) }) }) - }) - Context("simple write with buffer", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - for i := 0; i < 54; i++ { - Expect(c.Send(apns.Notification{})).To(BeNil()) + Context("invalid notification", func() { + It("should return an error", func() { + newSession = func(_ Conn) Session { + return &mockSession{ + sendCB: func(_ Notification) error { + return errors.New("") + }, + } } - close(mockDone) - close(d) - }) - }) - }) - - Context("multiple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - Expect(c.Send(apns.Notification{})).To(BeNil()) - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) - }) - }) - }) - - Context("bad push", func() { - n := apns.Notification{Identifier: 9, ID: "some_rando"} - nb, _ := n.ToBinary() - nbcb := make([]byte, len(nb)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(9)) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: nbcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(nb)) - }}, - - // Bad push results in a close - serverAction{action: writeAction, data: errPayload.Bytes()}, - serverAction{action: closeAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - go func() { - n := <-c.FailedNotifs - - Expect(n.Notif.Identifier).To(Equal(uint32(9))) - Expect(n.Notif.ID).To(Equal("some_rando")) - - close(mockDone) - close(d) - }() - - Expect(c.Send(n)).To(BeNil()) - }) - }) - }) - - Context("closed, reconnect", func() { - done := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n1b, _ := n1.ToBinary() - n1bcb := make([]byte, len(n1b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - as := [][]serverAction{ - []serverAction{ - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - done <- true - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - <-done - time.Sleep(5 * time.Millisecond) - - // Good - Expect(c.Send(n1)).To(BeNil()) + err = c.Send(Notification{DeviceToken: "lol"}) + Expect(err).NotTo(BeNil()) }) }) }) - Context("good, close, good, requeue of last good", func() { - closed := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - closed <- true - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - - close(mockDone) - close(d) - }}, - }, + Context("disconnected", func() { + It("should reconnect", func() { + newSessCount := 0 + newSession = func(_ Conn) Session { + newSessCount++ + return &mockSession{} } - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n1)).To(BeNil()) + c.sess.Disconnect() - <-closed - time.Sleep(5 * time.Millisecond) + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n2)).To(BeNil()) - }) + Expect(newSessCount).To(Equal(2)) }) }) - Context("good, bad, good, requeue of last good", func() { - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - n3 := apns.Notification{Identifier: 3} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - n3b, _ := n3.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - n3bcb := make([]byte, len(n3b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Read bad notification - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - }}, - - // Read second good notification - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - }}, - - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, + It("should reconnect and requeue", func() { + newSessCount := 0 + sendCount := 0 + + newSession = func(_ Conn) Session { + newSessCount++ + return &mockSession{ + requeueNotifs: []Notification{ + Notification{}, + Notification{}, + Notification{}, }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - - close(mockDone) - close(d) - }}, + sendCB: func(_ Notification) error { + sendCount++ + return nil }, } + } - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n1)).To(BeNil()) + c.sess.Disconnect() - // Bad - Expect(c.Send(n2)).To(BeNil()) + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n3)).To(BeNil()) - }) - }) + Expect(newSessCount).To(Equal(2)) + Expect(sendCount).To(Equal(4)) }) }) }) diff --git a/conn.go b/conn.go index d3aa712..d76577b 100644 --- a/conn.go +++ b/conn.go @@ -2,94 +2,109 @@ package apns import ( "crypto/tls" + "io" "net" "strings" + "time" ) const ( + // ProductionGateway is the host for Apple Push Notification server. ProductionGateway = "gateway.push.apple.com:2195" - SandboxGateway = "gateway.sandbox.push.apple.com:2195" + // SandboxGateway is Apple's gateway for development. + SandboxGateway = "gateway.sandbox.push.apple.com:2195" + // ProductionFeedbackGateway is Apple's feedback service. ProductionFeedbackGateway = "feedback.push.apple.com:2196" - SandboxFeedbackGateway = "feedback.sandbox.push.apple.com:2196" + // SandboxFeedbackGateway is Apple's feedback service for development. + SandboxFeedbackGateway = "feedback.sandbox.push.apple.com:2196" ) // Conn is a wrapper for the actual TLS connections made to Apple -type Conn struct { - NetConn net.Conn - Conf *tls.Config +type Conn interface { + io.ReadWriteCloser + + Connect() error + SetReadDeadline(deadline time.Time) error +} + +type conn struct { + netConn net.Conn + tls *tls.Config gateway string connected bool } +// NewConnWithCert creates a new connection from a certificate. func NewConnWithCert(gw string, cert tls.Certificate) Conn { gatewayParts := strings.Split(gw, ":") - conf := tls.Config{ - Certificates: []tls.Certificate{cert}, - ServerName: gatewayParts[0], + tls := tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: gatewayParts[0], + InsecureSkipVerify: true, } - return Conn{gateway: gw, Conf: &conf} + return &conn{gateway: gw, tls: &tls} } -// NewConnWithFiles creates a new Conn from certificate and key in the specified files +// NewConn creates a new Conn from certificate and key pair. func NewConn(gw string, crt string, key string) (Conn, error) { cert, err := tls.X509KeyPair([]byte(crt), []byte(key)) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil } -// NewConnWithFiles creates a new Conn from certificate and key in the specified files +// NewConnWithFiles creates a new Conn from certificate and key in the specified files. func NewConnWithFiles(gw string, certFile string, keyFile string) (Conn, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil } // Connect actually creates the TLS connection -func (c *Conn) Connect() error { +func (c *conn) Connect() error { // Make sure the existing connection is closed - if c.NetConn != nil { - c.NetConn.Close() + if c.netConn != nil { + c.netConn.Close() } - conn, err := net.Dial("tcp", c.gateway) + tlsConn, err := tls.Dial("tcp", c.gateway, c.tls) if err != nil { return err } - tlsConn := tls.Client(conn, c.Conf) - err = tlsConn.Handshake() - if err != nil { - return err - } - - c.NetConn = tlsConn + c.netConn = tlsConn return nil } -func (c *Conn) Close() error { - if c.NetConn != nil { - return c.NetConn.Close() +// Close the connection. +func (c *conn) Close() error { + if c.netConn != nil { + return c.netConn.Close() } return nil } // Read reads data from the connection -func (c *Conn) Read(p []byte) (int, error) { - i, err := c.NetConn.Read(p) - return i, err +func (c *conn) Read(p []byte) (int, error) { + return c.netConn.Read(p) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *conn) SetReadDeadline(deadline time.Time) error { + return c.netConn.SetReadDeadline(deadline) } // Write writes data from the connection -func (c *Conn) Write(p []byte) (int, error) { - return c.NetConn.Write(p) +func (c *conn) Write(p []byte) (int, error) { + return c.netConn.Write(p) } diff --git a/conn_test.go b/conn_test.go index e910e6c..d388a6d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,232 +1,15 @@ package apns_test import ( - "bytes" - "crypto/tls" - "fmt" - "io" "io/ioutil" - "log" "net" "os" - "strings" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) -var DummyCert = `-----BEGIN CERTIFICATE----- -MIIC9TCCAd+gAwIBAgIQf3bEgFWUb+q6eK5ySkV/gjALBgkqhkiG9w0BAQUwEjEQ -MA4GA1UEChMHQWNtZSBDbzAeFw0xNDA2MzAwNDI5MDhaFw0xNTA2MzAwNDI5MDha -MBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK -AoIBAQDhAgWrrFZBtCfVEPg1tSIr9fuSUoeundb556IUr9uOmOHaYK7r3/I43acw -bVIfaenFxwUUf8YakQzTjOa5qSfK/Eylyw2ezBJtNUEqcHw0f+y66+jJbZa4clPa -tL6ezaMS/syXPpvNU8+16jdVdTJzqdBdSGAZMOCeumUWDNdlfBmHPVq1JMy0uGmO -XDoZK2Ir0/3LUfjk9R2wdm1VLrJAml7F0L0FhBHHXgHOSFM2ixjGflffaiuTCxhW -1z1NTo9XjWUQh2iM9Udf+xVnJLGLZ0EMFr2qihuK604Fp4SlNHEF+UWUn+j0PYo+ -LbzM9oKJcdVD0XI36vrn3rGPHO9vAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIAoDAT -BgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuCCWxv -Y2FsaG9zdDALBgkqhkiG9w0BAQUDggEBAGJ/3I4KKlbEwLAC5ut4ZZ9V8WF4sHkI -Lj7e4vx2pPi6hf9miV1ff01NrpfUna7flwL9yD7Ybl7jRRIB4rIcKk+U5djGsT3H -ScGkbIMKrr08drWw1g4JU6PBH7xTfzGxNRERrnmrbJV0jCo9Tt8i53IpPtp6Z2Q1 -8ydtPhU+Bpe2YoNr1w1fSV1JHXqjKV8RlGkCNSi4ozPOO8RbAYnBT3d9XSGoX//q -RGJUf3wC/rCxJkN63Moxuy3vxV2TmiqccHOrXJSJ8P/4PpPV/xuBk5k4HS1Nfmew -d9WHHn6bMJE9arVvWAiu9teCadVffuS2cl2cicN4XB6Ui0aDqhG2Exw= ------END CERTIFICATE-----` - -var DummyKey = `-----BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEA4QIFq6xWQbQn1RD4NbUiK/X7klKHrp3W+eeiFK/bjpjh2mCu -69/yON2nMG1SH2npxccFFH/GGpEM04zmuaknyvxMpcsNnswSbTVBKnB8NH/suuvo -yW2WuHJT2rS+ns2jEv7Mlz6bzVPPteo3VXUyc6nQXUhgGTDgnrplFgzXZXwZhz1a -tSTMtLhpjlw6GStiK9P9y1H45PUdsHZtVS6yQJpexdC9BYQRx14BzkhTNosYxn5X -32orkwsYVtc9TU6PV41lEIdojPVHX/sVZySxi2dBDBa9qoobiutOBaeEpTRxBflF -lJ/o9D2KPi28zPaCiXHVQ9FyN+r6596xjxzvbwIDAQABAoIBAFzW+cIA5MJNdFX8 -n32BlGzxHPEd7nAFHmuUwJKqkPwAZsg1NleK2qXOByr7IHRnvhZl7Nmtcu8JRHKR -Y63ddtbRTUrnQmJwL3YyEAZTzVvYILRrnGxoNFU8jw7hnvllPdEbow0QvzZ0S3Lz -BgvTxJJm0dt7fnNGcJftrsHvYHy1dptaR4hPv0xV5G7RPrbTl94llKfi745tp5Wd -xGpnjcBXoAnzCVRij1tHfSYubRJ2MJV0kzG3oVdRV2P/zWaout8BlhLCURv4sRUX -7FfCNa/z+G6AlROjCKJUP9YIUbxBEa/aP8YlSiyLRi1jFbMWcnKWQUdqS19m73Ap -a1LJFPECgYEA+Ve5DegcrWnUb2HsHD38HlmEg6S+/jg2P4TsuLZBtvO4/vzRx/qq -pwuuMm2CsvXr4nVmMEsMlSzYdsnaXIlWqyVDCOwIWR5VYT2GDWqQLaIXPlFaISzN -27tHd64KUtR1fMJUwQVK/MUORUbpYoAnSIil2SlYkWUhF024fNP8CxcCgYEA5wP4 -HLiqU2rqe7vSAF/8fHwPleTzuCfMCVZm0aegUzQQQtklZoVE/BBwEGHdXflq1veq -pHeC8bNR4BF6ZgeSWgbLVF3msquy47QeNElHA2muJd3qmNWz4LXo1Pxb8KXcnXri -QZ+r3Y8obWTFQYq7gGQGPLXGTV3bhLGIyrT4lWkCgYAgZ2MYSJL5gmhmNT6fCPsr -4oxTI2Ti2uFJ7fdppd3ybcgb8zU8HPpyjRUNXqf+o/EM1B78pbQz6skS3vau0fZe -dZA5p5sKIeQMqBc0xSWJmKgWpDHnX9A8/yCxj/+tdgjytrqW/x4YrW9GV4nbEDaK -uZ98EmB9PLxJMAOKzW3S7wKBgQDD4PCy4b3CR2iVC9dva/P5VXQdo+knX884p6M8 -58YgZofXNqnouN2aYRG0QlbiBMcbiRqOo6tK58JnnEpNUuQ8I4Cqg4hGPSHMwv/N -U8i70xLPltABUUpZIcVPOr92WBytBvHrtMiUb3tW7lf3T/vWTHmhZnvDQ+8LH0Ge -pz4T6QKBgQCoBJKOd781IQmT6i5hHSYJlsP6ymaaaQniJPVpnci/jf8+2QtponQY -scgnaBLBasLQ6GfKSRtcyidEi9wwxpVj0tw2p567jeNcIveD0TOYFf0RHEfrs+D4 -VdRgai/v2NbFZLDnzeGVuYypXu6R78isJfHtz/a0aEave8yB3CRiDw== ------END RSA PRIVATE KEY-----` - -// To be able to run in parallel -var mockPort = 50000 - -// Mock Addr -type mockAddr struct { -} - -func (m mockAddr) Network() string { - return "localhost:56789" -} - -func (m mockAddr) String() string { - return "localhost:56789" -} - -// Mock TLS connection -type mockTLSNetConn struct { - bb *bytes.Buffer - err error -} - -func (t mockTLSNetConn) Read(p []byte) (int, error) { - r := bytes.NewReader(t.bb.Bytes()) - return r.Read(p) -} - -func (t mockTLSNetConn) Write(p []byte) (int, error) { - return t.bb.Write(p) -} - -func (t mockTLSNetConn) Close() error { - return t.err -} - -func (m mockTLSNetConn) LocalAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) RemoteAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) SetDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetWriteDeadline(t time.Time) error { - return nil -} - -type serverAction struct { - action string - data []byte - cb func(s serverAction) -} - -const ( - readAction = "read" - writeAction = "write" - closeAction = "close" -) - -type mockTLSServer struct { - Port int - Server net.Listener - ConnectionActionGroups [][]serverAction -} - -func (m *mockTLSServer) portStr() string { - if m.Port == 0 { - mockPort = mockPort + 1 - m.Port = mockPort - } - - return fmt.Sprint(m.Port) -} - -func (m *mockTLSServer) Address() string { - return "localhost:" + m.portStr() -} - -func (m *mockTLSServer) start() { - cert, err := tls.X509KeyPair([]byte(DummyCert), []byte(DummyKey)) - if err != nil { - log.Panic(err) - } - - config := tls.Config{Certificates: []tls.Certificate{cert}, ClientAuth: tls.RequireAnyClientCert} - - m.Server, err = tls.Listen("tcp", "localhost:"+m.portStr(), &config) - go func() { - for i := 0; i < len(m.ConnectionActionGroups); i++ { - g := m.ConnectionActionGroups[i] - - // Wait for a connection. - conn, err := m.Server.Accept() - if err != nil { - if strings.Contains(err.Error(), "use of closed network connection") { - return - } else { - log.Fatal(err) - } - } - // Handle the connection in a new goroutine. - // The loop then returns to accepting, so that - // multiple connections may be served concurrently. - go func(c net.Conn) { - for j := 0; j < len(g); j++ { - a := g[j] - switch a.action { - case readAction: - c.Read(a.data) - case writeAction: - c.Write(a.data) - case closeAction: - c.Close() - - if a.cb != nil { - a.cb(a) - } - return - } - - if a.cb != nil { - a.cb(a) - } - } - }(conn) - } - - // No more connection action groups - }() -} - -func (m *mockTLSServer) stop() { - if m.Server != nil { - m.Server.Close() - } -} - -var withMockServer = func(as [][]serverAction, cb func(s *mockTLSServer)) { - d := make(chan interface{}) - withMockServerAsync(as, d, func(s *mockTLSServer) { - cb(s) - close(d) - }) -} - -var withMockServerAsync = func(as [][]serverAction, d chan interface{}, cb func(s *mockTLSServer)) { - s := &mockTLSServer{} - s.ConnectionActionGroups = as - - s.start() - - cb(s) - - <-d - s.stop() -} - // Tests var _ = Describe("Conn", func() { Describe(".NewConn", func() { @@ -239,7 +22,7 @@ var _ = Describe("Conn", func() { Context("valid key/cert pair", func() { It("should not return an error", func() { - _, err := apns.NewConn(apns.SandboxGateway, DummyCert, DummyKey) + _, err := apns.NewConn(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -259,11 +42,11 @@ var _ = Describe("Conn", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -295,117 +78,122 @@ var _ = Describe("Conn", func() { }) Context("server up", func() { - as := [][]serverAction{[]serverAction{serverAction{action: readAction, data: []byte{}}}} - Context("with untrusted certs", func() { It("should return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - err := conn.Connect() - Expect(err).NotTo(BeNil()) + s := tcptest.NewTLSServer(func(c net.Conn) {}) + defer s.Close() - close(d) - }) + conn, err := apns.NewConn(s.Addr, "not trusted", "not even a little") + Expect(err).NotTo(BeNil()) + + err = conn.Connect() + Expect(err).NotTo(BeNil()) + + close(d) }) }) Context("trusting the certs", func() { It("should not return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + s := tcptest.NewUnstartedServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) - err := conn.Connect() - Expect(err).To(BeNil()) + s.StartTLS() + defer s.Close() - close(d) - }) - }) + conn, err := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = conn.Connect() + Expect(err).To(BeNil()) + + close(d) + }, 10) }) Context("with existing connection", func() { It("should not return an error", func(d Done) { - as = [][]serverAction{ - []serverAction{serverAction{action: readAction, data: []byte{}}}, - []serverAction{serverAction{action: readAction, data: []byte{}}}, - } + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) + defer s.Close() - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + conn.Connect() - err := conn.Connect() - Expect(err).To(BeNil()) + err := conn.Connect() + Expect(err).To(BeNil()) - close(d) - }) - }) + close(d) + }, 10) }) }) }) Describe("#Read", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte("hello!"))} - - pp := make([]byte, 6) - bytes.NewReader(rwc.bb.Bytes()).Read(pp) + It("should read out 'hello!'", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte("hello!")) + }) + defer s.Close() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - It("should read out 'hello!'", func() { p := make([]byte, 6) conn.Read(p) Expect(p).To(Equal([]byte("hello!"))) }) }) +}) - Describe("#Write", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} +var _ = Describe("#Write", func() { + It("should read out 'hello!'", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + b := make([]byte, 6) + c.Read(b) - It("should write out 'world!'", func() { - conn.Write([]byte("world!")) - Expect(rwc.bb.String()).To(Equal("world!")) + Expect(string(b)).To(Equal("hello!")) + close(d) }) - }) - Describe("#Close", func() { - Context("with connection", func() { - Context("no error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn.Write([]byte("hello!")) + }, 10) +}) - It("should return no error", func() { - Expect(rwc.Close()).To(BeNil()) +var _ = Describe("#Close", func() { + Context("with connection", func() { + Context("no error", func() { + It("should return no error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect }) - }) - - Context("with error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} - - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + defer s.Close() - rwc.err = io.EOF - It("should return that error", func() { - Expect(rwc.Close()).To(Equal(io.EOF)) - }) + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() + Expect(conn.Close()).To(BeNil()) }) }) + }) - Context("without connection", func() { - c, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - It("should not return an error", func() { - Expect(c.Close()).To(BeNil()) - }) + Context("without connection", func() { + It("should not return an error", func() { + conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(conn.Close()).To(BeNil()) }) }) }) diff --git a/doc.go b/doc.go index 02149d7..85b219e 100644 --- a/doc.go +++ b/doc.go @@ -1,13 +1,12 @@ /* -A Go package to interface with the Apple Push -Notification Service +Package apns provides an interface with the Apple Push Notification Service. Features This library implements a few features that we couldn't find in any one library elsewhere: - Long Lived Clients - Apple's documentation say that you should hold a + Long Lived Clients - Apple's documentation says that you should hold a persistent connection open and not create new connections for every payload See: https://developer.apple.com/library/ios/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/Chapters/CommunicatingWIthAPS.html#//apple_ref/doc/uid/TP40008194-CH101-SW6) @@ -16,7 +15,7 @@ library elsewhere: variable length payloads. This library uses that protocol. - Robust Send Guarantees - APNS has asynchronous feedback on whether a push + Robust Send Guarantees - Apple has asynchronous feedback on whether a push sent. That means that if you send pushes after a bad send, those pushes will be lost forever. Our library records the last N pushes, detects errors, diff --git a/error.go b/error.go index 5425868..20b390b 100644 --- a/error.go +++ b/error.go @@ -5,19 +5,35 @@ import ( "encoding/binary" ) +var ( + // ErrUnrecognizedErrorResponse when the error from Apple isn't recognized. + ErrUnrecognizedErrorResponse = "Unrecognized error or no error." +) + const ( // Error strings based on the codes specified here: // https://developer.apple.com/library/ios/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/Chapters/CommunicatingWIthAPS.html#//apple_ref/doc/uid/TP40008194-CH101-SW12 - ErrProcessing = "Processing error" + + // ErrProcessing (1) + ErrProcessing = "Processing error" + // ErrMissingDeviceToken (2) ErrMissingDeviceToken = "Missing device token" - ErrMissingTopic = "Missing topic" - ErrMissingPayload = "Missing payload" - ErrInvalidTokenSize = "Invalid token size" - ErrInvalidTopicSize = "Invalid topic size" + // ErrMissingTopic (3) + ErrMissingTopic = "Missing topic" + // ErrMissingPayload (4) when no payload. + ErrMissingPayload = "Missing payload" + // ErrInvalidTokenSize (5) for a device token that is the wrong size. + ErrInvalidTokenSize = "Invalid token size" + // ErrInvalidTopicSize (6) + ErrInvalidTopicSize = "Invalid topic size" + // ErrInvalidPayloadSize (7) for a payload over 2 KB. ErrInvalidPayloadSize = "Invalid payload size" - ErrInvalidToken = "Invalid token" - ErrShutdown = "Shutdown" - ErrUnknown = "None (unknown)" + // ErrInvalidToken (8) such as a production device token used with the sandbox gateway. + ErrInvalidToken = "Invalid token" + // ErrShutdown (10) closed connection to perform maintenance. Open a new connection. + ErrShutdown = "Shutdown" + // ErrUnknown (255) + ErrUnknown = "None (unknown)" ) var errorMapping = map[uint8]string{ @@ -33,16 +49,18 @@ var errorMapping = map[uint8]string{ 255: ErrUnknown, } +// Error captures the error response read from Apple's Push Notification server. type Error struct { - Command uint8 + Command uint8 // always should be 8 Status uint8 Identifier uint32 ErrStr string } +// NewError parses an error from Apple. func NewError(p []byte) Error { - if len(p) != 1+1+4 { - return Error{ErrStr: ErrUnknown} + if len(p) != 6 { + return Error{ErrStr: ErrUnrecognizedErrorResponse} } r := bytes.NewBuffer(p) @@ -54,7 +72,7 @@ func NewError(p []byte) Error { var ok bool if e.ErrStr, ok = errorMapping[e.Status]; !ok { - e.ErrStr = ErrUnknown + e.ErrStr = ErrUnrecognizedErrorResponse } return e diff --git a/error_test.go b/error_test.go index 1c6bb78..12d063a 100644 --- a/error_test.go +++ b/error_test.go @@ -79,14 +79,14 @@ var _ = Describe("Error", func() { }) Context("error with unrecognized code", func() { - ShouldBeErrorWithErrStr(300, apns.ErrUnknown) + ShouldBeErrorWithErrStr(300, apns.ErrUnrecognizedErrorResponse) }) Context("not enough bytes", func() { - It("should be ErrUnknown", func() { + It("should be ErrUnrecognizedErrorResponse", func() { e := apns.NewError([]byte{}) Expect(e).NotTo(BeNil()) - Expect(e.ErrStr).To(Equal(apns.ErrUnknown)) + Expect(e.ErrStr).To(Equal(apns.ErrUnrecognizedErrorResponse)) }) }) }) diff --git a/example/example.go b/example/example.go index 1b670ac..637d0af 100644 --- a/example/example.go +++ b/example/example.go @@ -10,10 +10,10 @@ import ( func main() { c, err := apns.NewClientWithFiles(apns.ProductionGateway, "apns.crt", "apns.key") if err != nil { - log.Fatal("Could not create client", err.Error()) + log.Fatal("Could not create client: ", err.Error()) } - i := 0 + i := 1 for { fmt.Print("Enter ' ': ") diff --git a/feedback.go b/feedback.go index 488bf1b..80d0e8a 100644 --- a/feedback.go +++ b/feedback.go @@ -8,10 +8,12 @@ import ( "time" ) +// Feedback is a connection to Apple's feedback service. type Feedback struct { - Conn *Conn + Conn Conn } +// FeedbackTuple represents the feedback received from Apple. type FeedbackTuple struct { Timestamp time.Time TokenLength uint16 @@ -37,28 +39,31 @@ func feedbackTupleFromBytes(b []byte) FeedbackTuple { } } +// NewFeedbackWithCert creates a new feedback service client with a certificate. func NewFeedbackWithCert(gw string, cert tls.Certificate) Feedback { conn := NewConnWithCert(gw, cert) - return Feedback{Conn: &conn} + return Feedback{Conn: conn} } +// NewFeedback creates a new feedback service client with a certificate/key pair. func NewFeedback(gw string, cert string, key string) (Feedback, error) { conn, err := NewConn(gw, cert, key) if err != nil { return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } +// NewFeedbackWithFiles creates a new feedback service client from certificate and key files. func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, error) { conn, err := NewConnWithFiles(gw, certFile, keyFile) if err != nil { return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } // Receive returns a read only channel for APNs feedback. The returned channel @@ -80,9 +85,13 @@ func (f Feedback) receive(fc chan FeedbackTuple) { for { b := make([]byte, 38) - f.Conn.NetConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + err = f.Conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if err != nil { + close(fc) + return + } - _, err := f.Conn.Read(b) + _, err = f.Conn.Read(b) if err != nil { close(fc) return diff --git a/feedback_test.go b/feedback_test.go index 29978b4..59739e2 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -4,12 +4,15 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "io" "io/ioutil" + "net" "os" - "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) var _ = Describe("Feedback", func() { @@ -23,7 +26,7 @@ var _ = Describe("Feedback", func() { Context("valid cert/key pair", func() { It("should create a valid client", func() { - _, err := apns.NewFeedback(apns.ProductionGateway, DummyCert, DummyKey) + _, err := apns.NewFeedback(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -42,11 +45,11 @@ var _ = Describe("Feedback", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -70,16 +73,18 @@ var _ = Describe("Feedback", func() { Describe("#Receive", func() { Context("could not connect", func() { It("should not receive anything", func() { - s := &mockTLSServer{} - - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true + m := mockConn{ + connect: func() error { + return io.EOF + }, + } + f := apns.Feedback{Conn: &m} c := f.Receive() r := 0 for _ = range c { - r += 1 + r++ } Expect(r).To(Equal(0)) @@ -87,89 +92,87 @@ var _ = Describe("Feedback", func() { }) Context("times out", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - It("should not receive anything", func() { - c := f.Receive() - - r := 0 - for _ = range c { - r += 1 - } - - Expect(r).To(Equal(0)) - }) + It("should not receive anything", func() { + m := mockConn{ + read: func(b []byte) (int, error) { + return 0, net.UnknownNetworkError("") + }, + } + + f := apns.Feedback{Conn: &m} + c := f.Receive() + + r := 0 + for _ = range c { + r++ + } + + Expect(r).To(Equal(0)) }) }) + }) + + Context("with feedback", func() { + f1 := bytes.NewBuffer([]byte{}) + f2 := bytes.NewBuffer([]byte{}) + f3 := bytes.NewBuffer([]byte{}) - Context("with feedback", func() { - f1 := bytes.NewBuffer([]byte{}) - f2 := bytes.NewBuffer([]byte{}) - f3 := bytes.NewBuffer([]byte{}) - - // The final token strings - t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" - t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" - t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" - - bt1, _ := hex.DecodeString(t1) - bt2, _ := hex.DecodeString(t2) - bt3, _ := hex.DecodeString(t3) - - binary.Write(f1, binary.BigEndian, uint32(1404358249)) - binary.Write(f1, binary.BigEndian, uint16(len(bt1))) - binary.Write(f1, binary.BigEndian, bt1) - - binary.Write(f2, binary.BigEndian, uint32(1404352249)) - binary.Write(f2, binary.BigEndian, uint16(len(bt2))) - binary.Write(f2, binary.BigEndian, bt2) - - binary.Write(f3, binary.BigEndian, uint32(1394352249)) - binary.Write(f3, binary.BigEndian, uint16(len(bt3))) - binary.Write(f3, binary.BigEndian, bt3) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: writeAction, data: f1.Bytes()}, - serverAction{action: writeAction, data: f2.Bytes()}, - serverAction{action: writeAction, data: f3.Bytes()}, - }, - } - - It("should receive feedback", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - c := f.Receive() - - r1 := <-c - Expect(r1.Timestamp).To(Equal(time.Unix(1404358249, 0))) - Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) - Expect(r1.DeviceToken).To(Equal(t1)) - - r2 := <-c - Expect(r2.Timestamp).To(Equal(time.Unix(1404352249, 0))) - Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) - Expect(r2.DeviceToken).To(Equal(t2)) - - r3 := <-c - Expect(r3.Timestamp).To(Equal(time.Unix(1394352249, 0))) - Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) - Expect(r3.DeviceToken).To(Equal(t3)) - - <-c - close(d) - }) + // The final token strings + t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" + t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" + t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" + + bt1, _ := hex.DecodeString(t1) + bt2, _ := hex.DecodeString(t2) + bt3, _ := hex.DecodeString(t3) + + binary.Write(f1, binary.BigEndian, uint32(1404358249)) + binary.Write(f1, binary.BigEndian, uint16(len(bt1))) + binary.Write(f1, binary.BigEndian, bt1) + + binary.Write(f2, binary.BigEndian, uint32(1404352249)) + binary.Write(f2, binary.BigEndian, uint16(len(bt2))) + binary.Write(f2, binary.BigEndian, bt2) + + binary.Write(f3, binary.BigEndian, uint32(1394352249)) + binary.Write(f3, binary.BigEndian, uint16(len(bt3))) + binary.Write(f3, binary.BigEndian, bt3) + + It("should receive feedback", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write(f1.Bytes()) + c.Write(f2.Bytes()) + c.Write(f3.Bytes()) + + // TODO(bw) this doesn't seem right + c.Write([]byte{0}) + + c.Close() }) - }) + defer s.Close() + + f, err := apns.NewFeedback(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c := f.Receive() + + r1 := <-c + Expect(r1.Timestamp.Unix()).To(Equal(int64(1404358249))) + Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) + Expect(r1.DeviceToken).To(Equal(t1)) + + r2 := <-c + Expect(r2.Timestamp.Unix()).To(Equal(int64(1404352249))) + Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) + Expect(r2.DeviceToken).To(Equal(t2)) + + r3 := <-c + Expect(r3.Timestamp.Unix()).To(Equal(int64(1394352249))) + Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) + Expect(r3.DeviceToken).To(Equal(t3)) + + <-c + close(d) + }, 10) }) }) diff --git a/notification.go b/notification.go index de1b6ff..1dd8d5d 100644 --- a/notification.go +++ b/notification.go @@ -11,10 +11,16 @@ import ( ) const ( - PriorityImmediate = 10 + // PriorityImmediate for time sensitive notifications (not for silent push messages). + PriorityImmediate = 10 + // PriorityPowerConserve for notifications that are less time sensitive. PriorityPowerConserve = 5 ) +const ( + validDeviceTokenLength = 64 +) + const ( commandID = 2 @@ -32,13 +38,8 @@ const ( priorityItemLength = 1 ) -type NotificationResult struct { - Notif Notification - Err Error -} - +// Alert to send. type Alert struct { - // Do not add fields without updating the implementation of isZero. Body string `json:"body,omitempty"` Title string `json:"title,omitempty"` Action string `json:"action,omitempty"` @@ -46,8 +47,11 @@ type Alert struct { LocArgs []string `json:"loc-args,omitempty"` ActionLocKey string `json:"action-loc-key,omitempty"` LaunchImage string `json:"launch-image,omitempty"` + + // Do not add fields without updating the implementation of isZero. } +// isSimple alerts only contain a Body. func (a *Alert) isSimple() bool { return len(a.Title) == 0 && len(a.Action) == 0 && len(a.LocKey) == 0 && len(a.LocArgs) == 0 && len(a.ActionLocKey) == 0 && len(a.LaunchImage) == 0 } @@ -56,6 +60,7 @@ func (a *Alert) isZero() bool { return a.isSimple() && len(a.Body) == 0 } +// APS is the Apple-reserved aps namespace in a push notification. type APS struct { Alert Alert Badge *int // 0 to clear notifications, nil to leave as is. @@ -65,6 +70,7 @@ type APS struct { Category string // requires iOS 8+ } +// MarshalJSON implements the json.Marshaler interface. func (aps APS) MarshalJSON() ([]byte, error) { data := make(map[string]interface{}) @@ -94,6 +100,7 @@ func (aps APS) MarshalJSON() ([]byte, error) { return json.Marshal(data) } +// Payload to send Apple. type Payload struct { APS APS // MDM for mobile device management @@ -101,6 +108,7 @@ type Payload struct { customValues map[string]interface{} } +// MarshalJSON implements the json.Marshaler interface. func (p *Payload) MarshalJSON() ([]byte, error) { if len(p.MDM) != 0 { p.customValues["mdm"] = p.MDM @@ -111,6 +119,7 @@ func (p *Payload) MarshalJSON() ([]byte, error) { return json.Marshal(p.customValues) } +// SetCustomValue sets a custom payload value. func (p *Payload) SetCustomValue(key string, value interface{}) error { if key == "aps" { return errors.New("cannot assign a custom APS value in payload") @@ -121,6 +130,7 @@ func (p *Payload) SetCustomValue(key string, value interface{}) error { return nil } +// Notification contains the payload. type Notification struct { ID string DeviceToken string @@ -130,23 +140,33 @@ type Notification struct { Payload *Payload } +// NewNotification creates a new notification. func NewNotification() Notification { return Notification{Payload: NewPayload()} } +// NewPayload creates a new payload. func NewPayload() *Payload { return &Payload{customValues: map[string]interface{}{}} } +// ToBinary encodes a notification to send it. func (n Notification) ToBinary() ([]byte, error) { b := []byte{} + if len(n.DeviceToken) != validDeviceTokenLength { + return b, errors.New(ErrInvalidToken) + } + binTok, err := hex.DecodeString(n.DeviceToken) if err != nil { return b, fmt.Errorf("convert token to hex error: %s", err) } - j, _ := json.Marshal(n.Payload) + j, err := json.Marshal(n.Payload) + if err != nil { + return b, fmt.Errorf("json marshal error: %s", err) + } buf := bytes.NewBuffer(b) diff --git a/notification_test.go b/notification_test.go index f0bf6cf..da77814 100644 --- a/notification_test.go +++ b/notification_test.go @@ -193,7 +193,20 @@ var _ = Describe("Notifications", func() { Describe("#ToBinary", func() { Context("invalid token format", func() { n := apns.NewNotification() - n.DeviceToken = "totally not a valid token" + n.DeviceToken = "totally not a valid token length" + + It("should return an error", func() { + _, err := n.ToBinary() + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(Equal(apns.ErrInvalidToken)) + }) + + // Expect(err.Error()).To(ContainSubstring("convert token to hex error")) + }) + + Context("non-convertable token", func() { + n := apns.NewNotification() + n.DeviceToken = "123456789012345678901234567890zz123456789012345678901234567890zz" It("should return an error", func() { _, err := n.ToBinary() diff --git a/session.go b/session.go new file mode 100644 index 0000000..5a5fc3b --- /dev/null +++ b/session.go @@ -0,0 +1,144 @@ +package apns + +import ( + "errors" + "io" + "sync" +) + +// Session to Apple's Push Notification server. +type Session interface { + Send(n Notification) error + Connect() error + RequeueableNotifications() []Notification + Disconnect() + Disconnected() bool +} + +type sessionState int + +const ( + sessionStateNew sessionState = 1 << iota + sessionStateConnected + sessionStateDisconnected +) + +type session struct { + b *buffer + + conn Conn + connm sync.Mutex + + st sessionState + stm sync.Mutex + + err Error +} + +// NewSession creates a new session. +func NewSession(conn Conn) Session { + return &session{ + st: sessionStateNew, + stm: sync.Mutex{}, + conn: conn, + connm: sync.Mutex{}, + b: newBuffer(50), + } +} + +// Connect session to gateway. +func (s *session) Connect() error { + if s.isNew() { + return errors.New("can't connect unless the session is new") + } + + go s.readErrors() + return nil +} + +func (s *session) isNew() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st != sessionStateNew +} + +// Disconnected indicates whether session is disconnected. +func (s *session) Disconnected() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st == sessionStateDisconnected +} + +func (s *session) RequeueableNotifications() []Notification { + return s.b.RequeueableNotifications(s.err.Identifier) +} + +// Connected indicates whether session is connected. +func (s *session) Connected() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st == sessionStateConnected +} + +// Send notification to gateway. +func (s *session) Send(n Notification) error { + // If disconnected, error out + if !s.Connected() { + return errors.New("not connected") + } + + // Serialize + b, err := n.ToBinary() + if err != nil { + return err + } + + // Add to buffer + s.b.Add(n) + + // Send synchronously + return s.write(b) +} + +func (s *session) write(b []byte) error { + s.connm.Lock() + defer s.connm.Unlock() + + _, err := s.conn.Write(b) + if err == io.EOF { + s.Disconnect() + return err + } + + return err +} + +// Disconnect from gateway. +func (s *session) Disconnect() { + s.transitionState(sessionStateDisconnected) +} + +func (s *session) transitionState(st sessionState) { + s.stm.Lock() + defer s.stm.Unlock() + + s.st = st +} + +func (s *session) readErrors() { + p := make([]byte, 6, 6) + + _, err := s.conn.Read(p) + // TODO(bw) not sure what to do here. It's unclear what errors + // come out of this and how we handle it. + if err != nil { + return + } + + s.Disconnect() + + s.err = NewError(p) +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..e70cbfb --- /dev/null +++ b/session_test.go @@ -0,0 +1,92 @@ +package apns + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mockConn struct{} + +func (m mockConn) Read(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Write(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) Connect() error { + return nil +} + +func (m mockConn) SetReadDeadline(deadline time.Time) error { + return nil +} + +var _ = Describe("Session", func() { + Describe("NewSession", func() { + It("creates a session", func() { + s := NewSession(mockConn{}) + Expect(s).NotTo(BeNil()) + }) + }) + + Describe("Connect", func() { + Context("new state", func() { + It("should not return an error", func() { + s := NewSession(mockConn{}) + + err := s.Connect() + Expect(err).To(BeNil()) + }) + }) + + Context("not new state", func() { + It("should return an error", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + err := s.Connect() + Expect(err).NotTo(BeNil()) + }) + }) + }) + + Describe("Disconnected", func() { + Context("not connected", func() { + It("should not be true", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + Expect(s.Disconnected()).To(BeTrue()) + }) + }) + + Context("connected", func() { + It("should be false", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.Connect() + + Expect(s.Disconnected()).To(BeFalse()) + }) + }) + }) + + Describe("Send", func() { + }) + + Describe("Disconnect", func() { + }) +})