From f985ea7765508a78945ded6d58d68e1bb959e694 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 10 Dec 2014 23:44:53 -0500 Subject: [PATCH 01/32] Start performing a lot more error checking in Send() --- client.go | 72 ++++++++++++++++++++++++++++++++----------------- notification.go | 8 ++++++ 2 files changed, 55 insertions(+), 25 deletions(-) diff --git a/client.go b/client.go index de7ab1a..b9a31f1 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "io" "log" + "sync" "time" ) @@ -27,20 +28,28 @@ func (b *buffer) Add(v interface{}) *list.Element { return e } +type serializedNotif struct { + id uint32 + b []byte +} + type Client struct { Conn *Conn FailedNotifs chan NotificationResult - notifs chan Notification - id uint32 + notifs chan serializedNotif + + id uint32 + idm sync.Mutex } func newClientWithConn(gw string, conn Conn) Client { c := Client{ Conn: &conn, FailedNotifs: make(chan NotificationResult), - id: uint32(1), - notifs: make(chan Notification), + notifs: make(chan serializedNotif), + id: 1, + idm: sync.Mutex{}, } go c.runLoop() @@ -73,10 +82,37 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err } func (c *Client) Send(n Notification) error { - c.notifs <- n + // Set identifier if not specified + if n.Identifier == 0 { + n.Identifier = c.nextID() + } else if c.id < n.Identifier { + c.setID(n.Identifier) + } + + b, err := n.ToBinary() + if err != nil { + return err + } + + c.notifs <- serializedNotif{b: b, id: n.Identifier} return nil } +func (c *Client) setID(n uint32) { + c.idm.Lock() + defer c.idm.Unlock() + + c.id = n +} + +func (c *Client) nextID() uint32 { + c.idm.Lock() + defer c.idm.Unlock() + + c.id++ + return c.id +} + func (c *Client) reportFailedPush(v interface{}, err *Error) { failedNotif, ok := v.(Notification) if !ok || v == nil { @@ -93,7 +129,7 @@ func (c *Client) requeue(cursor *list.Element) { // If `cursor` is not nil, this means there are notifications that // need to be delivered (or redelivered) for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(Notification); ok { + if n, ok := cursor.Value.(serializedNotif); ok { go func() { c.notifs <- n }() } } @@ -103,11 +139,11 @@ func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { cursor := buffer.Back() for cursor != nil { - // Get notification - n, _ := cursor.Value.(Notification) + // Get serialized notification + n, _ := cursor.Value.(serializedNotif) // If the notification, move cursor after the trouble notification - if n.Identifier == err.Identifier { + if n.id == err.Identifier { go c.reportFailedPush(cursor.Value, err) next := cursor.Next() @@ -143,7 +179,7 @@ func (c *Client) runLoop() { // Connection open, listen for notifs and errors for { var err error - var n Notification + var n serializedNotif // Check for notifications or errors. There is a chance we'll send notifications // if we already have an error since `select` will "pseudorandomly" choose a @@ -169,21 +205,7 @@ func (c *Client) runLoop() { // Add to list cursor = sent.Add(n) - // Set identifier if not specified - if n.Identifier == 0 { - n.Identifier = c.id - c.id++ - } else if c.id < n.Identifier { - c.id = n.Identifier + 1 - } - - b, err := n.ToBinary() - if err != nil { - // TODO - continue - } - - _, err = c.Conn.Write(b) + _, err = c.Conn.Write(n.b) if err == io.EOF { log.Println("EOF trying to write notification") diff --git a/notification.go b/notification.go index de1b6ff..7fe827e 100644 --- a/notification.go +++ b/notification.go @@ -15,6 +15,10 @@ const ( PriorityPowerConserve = 5 ) +const ( + validDeviceTokenLength = 64 +) + const ( commandID = 2 @@ -141,6 +145,10 @@ func NewPayload() *Payload { func (n Notification) ToBinary() ([]byte, error) { b := []byte{} + if len(n.DeviceToken) != validDeviceTokenLength { + return b, errors.New(ErrInvalidToken) + } + binTok, err := hex.DecodeString(n.DeviceToken) if err != nil { return b, fmt.Errorf("convert token to hex error: %s", err) From 350d3e323460376e1b32dd494e0247b43f5eb8cf Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 10 Dec 2014 23:58:05 -0500 Subject: [PATCH 02/32] Pass the original notification through if we want it with NotificationResult --- client.go | 14 +++++--------- notification.go | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index b9a31f1..fdc62ae 100644 --- a/client.go +++ b/client.go @@ -31,6 +31,7 @@ func (b *buffer) Add(v interface{}) *list.Element { type serializedNotif struct { id uint32 b []byte + n *Notification } type Client struct { @@ -94,7 +95,7 @@ func (c *Client) Send(n Notification) error { return err } - c.notifs <- serializedNotif{b: b, id: n.Identifier} + c.notifs <- serializedNotif{b: b, id: n.Identifier, n: &n} return nil } @@ -113,14 +114,9 @@ func (c *Client) nextID() uint32 { return c.id } -func (c *Client) reportFailedPush(v interface{}, err *Error) { - failedNotif, ok := v.(Notification) - if !ok || v == nil { - return - } - +func (c *Client) reportFailedPush(s serializedNotif, err *Error) { select { - case c.FailedNotifs <- NotificationResult{Notif: failedNotif, Err: *err}: + case c.FailedNotifs <- NotificationResult{Notif: s.n, Err: *err}: default: } } @@ -144,7 +140,7 @@ func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { // If the notification, move cursor after the trouble notification if n.id == err.Identifier { - go c.reportFailedPush(cursor.Value, err) + go c.reportFailedPush(n, err) next := cursor.Next() diff --git a/notification.go b/notification.go index 7fe827e..bfc198b 100644 --- a/notification.go +++ b/notification.go @@ -37,7 +37,7 @@ const ( ) type NotificationResult struct { - Notif Notification + Notif *Notification Err Error } From 7ce2c474e8491d3d53e92fbb41a4b276ca3e918a Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 11 Dec 2014 19:07:49 -0500 Subject: [PATCH 03/32] Undo NotificationResult api change --- client.go | 2 +- notification.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index fdc62ae..2d27b5d 100644 --- a/client.go +++ b/client.go @@ -116,7 +116,7 @@ func (c *Client) nextID() uint32 { func (c *Client) reportFailedPush(s serializedNotif, err *Error) { select { - case c.FailedNotifs <- NotificationResult{Notif: s.n, Err: *err}: + case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: *err}: default: } } diff --git a/notification.go b/notification.go index bfc198b..7fe827e 100644 --- a/notification.go +++ b/notification.go @@ -37,7 +37,7 @@ const ( ) type NotificationResult struct { - Notif *Notification + Notif Notification Err Error } From f162ceda2d0ca69f208b096a62350d54cda271be Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Mon, 15 Dec 2014 21:06:02 -0500 Subject: [PATCH 04/32] WIP --- client.go | 62 ++++++++++++++++++++++++++++++++++++++----------------- error.go | 5 +++++ 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 2d27b5d..69a135f 100644 --- a/client.go +++ b/client.go @@ -6,7 +6,6 @@ import ( "io" "log" "sync" - "time" ) type buffer struct { @@ -28,7 +27,7 @@ func (b *buffer) Add(v interface{}) *list.Element { return e } -type serializedNotif struct { +type serialized struct { id uint32 b []byte n *Notification @@ -38,23 +37,26 @@ type Client struct { Conn *Conn FailedNotifs chan NotificationResult - notifs chan serializedNotif + notifs chan serialized id uint32 idm sync.Mutex + + connected bool + connm sync.Mutex } func newClientWithConn(gw string, conn Conn) Client { c := Client{ Conn: &conn, FailedNotifs: make(chan NotificationResult), - notifs: make(chan serializedNotif), + notifs: make(chan serialized), id: 1, idm: sync.Mutex{}, + connected: false, + connm: sync.Mutex{}, } - go c.runLoop() - return c } @@ -82,7 +84,21 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err return newClientWithConn(gw, conn), nil } +func (c *Client) Connect() error { + err := c.Conn.Connect() + if err != nil { + return err + } + + go c.runLoop() + return nil +} + func (c *Client) Send(n Notification) error { + if !c.connected { + return ErrDisconnected + } + // Set identifier if not specified if n.Identifier == 0 { n.Identifier = c.nextID() @@ -95,7 +111,7 @@ func (c *Client) Send(n Notification) error { return err } - c.notifs <- serializedNotif{b: b, id: n.Identifier, n: &n} + c.notifs <- serialized{b: b, id: n.Identifier, n: &n} return nil } @@ -114,7 +130,21 @@ func (c *Client) nextID() uint32 { return c.id } -func (c *Client) reportFailedPush(s serializedNotif, err *Error) { +func (c *Client) connected() { + c.connm.Lock() + defer c.connm.Unlock() + + c.connected = true +} + +func (c *Client) disconnected() { + c.connm.Lock() + defer c.connm.Unlock() + + c.connected = false +} + +func (c *Client) reportFailedPush(s serialized, err *Error) { select { case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: *err}: default: @@ -125,7 +155,7 @@ func (c *Client) requeue(cursor *list.Element) { // If `cursor` is not nil, this means there are notifications that // need to be delivered (or redelivered) for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(serializedNotif); ok { + if n, ok := cursor.Value.(serialized); ok { go func() { c.notifs <- n }() } } @@ -136,7 +166,7 @@ func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { for cursor != nil { // Get serialized notification - n, _ := cursor.Value.(serializedNotif) + n, _ := cursor.Value.(serialized) // If the notification, move cursor after the trouble notification if n.id == err.Identifier { @@ -160,13 +190,6 @@ func (c *Client) runLoop() { // APNS connection for { - err := c.Conn.Connect() - if err != nil { - // TODO Probably want to exponentially backoff... - time.Sleep(1 * time.Second) - continue - } - // Start reading errors from APNS errs := readErrs(c.Conn) @@ -175,7 +198,7 @@ func (c *Client) runLoop() { // Connection open, listen for notifs and errors for { var err error - var n serializedNotif + var n serialized // Check for notifications or errors. There is a chance we'll send notifications // if we already have an error since `select` will "pseudorandomly" choose a @@ -205,7 +228,8 @@ func (c *Client) runLoop() { if err == io.EOF { log.Println("EOF trying to write notification") - break + c.connected = false + return } if err != nil { diff --git a/error.go b/error.go index 5425868..3ff4b86 100644 --- a/error.go +++ b/error.go @@ -3,6 +3,11 @@ package apns import ( "bytes" "encoding/binary" + "errors" +) + +const ( + ErrDisconnected = errors.New("disconnected from gateway") ) const ( From d308f5e08c955978aca9821f988a0368ef86caa1 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 16 Dec 2014 00:37:16 -0500 Subject: [PATCH 05/32] Start simplifying the client internals --- client.go | 201 +++++++++++++++++++----------------------------------- error.go | 2 +- 2 files changed, 71 insertions(+), 132 deletions(-) diff --git a/client.go b/client.go index 69a135f..65ad2c4 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ import ( "container/list" "crypto/tls" "io" - "log" "sync" ) @@ -39,6 +38,9 @@ type Client struct { notifs chan serialized + buffer *buffer + cursor *list.Element + id uint32 idm sync.Mutex @@ -51,7 +53,9 @@ func newClientWithConn(gw string, conn Conn) Client { Conn: &conn, FailedNotifs: make(chan NotificationResult), notifs: make(chan serialized), - id: 1, + buffer: newBuffer(50), + cursor: nil, + id: 0, idm: sync.Mutex{}, connected: false, connm: sync.Mutex{}, @@ -85,12 +89,20 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err } func (c *Client) Connect() error { - err := c.Conn.Connect() - if err != nil { + if err := c.Conn.Connect(); err != nil { + return err + } + + // On connect, requeue any notifications that were + // sent after the error & disconnect. + // http://redth.codes/the-problem-with-apples-push-notification-ser/ + if err := c.requeue(); err != nil { return err } - go c.runLoop() + // Kick off asynchronous error reading + go c.readErrors() + return nil } @@ -100,162 +112,89 @@ func (c *Client) Send(n Notification) error { } // Set identifier if not specified - if n.Identifier == 0 { - n.Identifier = c.nextID() - } else if c.id < n.Identifier { - c.setID(n.Identifier) - } + n.Identifier = c.determineIdentifier(n.Identifier) b, err := n.ToBinary() if err != nil { return err } - c.notifs <- serialized{b: b, id: n.Identifier, n: &n} - return nil -} + // Add to list + c.cursor = c.buffer.Add(n) -func (c *Client) setID(n uint32) { - c.idm.Lock() - defer c.idm.Unlock() + _, err = c.Conn.Write(b) + if err == io.EOF { + c.connected = false + return err + } + + if err != nil { + return err + } - c.id = n + c.cursor = c.cursor.Next() + return nil } -func (c *Client) nextID() uint32 { +func (c *Client) determineIdentifier(n uint32) uint32 { c.idm.Lock() defer c.idm.Unlock() - c.id++ - return c.id -} - -func (c *Client) connected() { - c.connm.Lock() - defer c.connm.Unlock() - - c.connected = true -} - -func (c *Client) disconnected() { - c.connm.Lock() - defer c.connm.Unlock() - - c.connected = false -} - -func (c *Client) reportFailedPush(s serialized, err *Error) { - select { - case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: *err}: - default: + // If the id passed in is 0, that means it wasn't + // set so get the next ID. Otherwise, set it to that + // identifier. + if n == 0 { + c.id++ + } else { + c.id = n } + + return c.id } -func (c *Client) requeue(cursor *list.Element) { +func (c *Client) requeue() error { // If `cursor` is not nil, this means there are notifications that // need to be delivered (or redelivered) - for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(serialized); ok { - go func() { c.notifs <- n }() - } - } -} - -func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { - cursor := buffer.Back() - - for cursor != nil { - // Get serialized notification - n, _ := cursor.Value.(serialized) - - // If the notification, move cursor after the trouble notification - if n.id == err.Identifier { - go c.reportFailedPush(n, err) - - next := cursor.Next() - - buffer.Remove(cursor) - return next + for ; c.cursor != nil; c.cursor = c.cursor.Next() { + if s, ok := c.cursor.Value.(serialized); ok { + if err := c.Send(*s.n); err != nil { + return err + } } - - cursor = cursor.Prev() } - return cursor + return nil } -func (c *Client) runLoop() { - sent := newBuffer(50) - cursor := sent.Front() +func (c *Client) readErrors() { + p := make([]byte, 6, 6) - // APNS connection - for { - // Start reading errors from APNS - errs := readErrs(c.Conn) + _, err := c.Conn.Read(p) + // TODO(bw) not sure what to do here. It's unclear what errors + // come out of this and how we handle it. + if err != nil { + return + } - c.requeue(cursor) + e := NewError(p) + cursor := c.buffer.Back() - // Connection open, listen for notifs and errors - for { - var err error - var n serialized + for cursor != nil { + // Get serialized notification + s, _ := cursor.Value.(serialized) - // Check for notifications or errors. There is a chance we'll send notifications - // if we already have an error since `select` will "pseudorandomly" choose a - // ready channels. It turns out to be fine because the connection will already - // be closed and it'll requeue. We could check before we get to this select - // block, but it doesn't seem worth the extra code and complexity. + // If the notification, move cursor after the trouble notification + if s.id == e.Identifier { + // Try to write - skip if no one is reading on the other side select { - case err = <-errs: - case n = <-c.notifs: - } - - // If there is an error we understand, find the notification that failed, - // move the cursor right after it. - if nErr, ok := err.(*Error); ok { - cursor = c.handleError(nErr, sent) - break + case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: e}: + default: } - if err != nil { - break - } - - // Add to list - cursor = sent.Add(n) - - _, err = c.Conn.Write(n.b) - - if err == io.EOF { - log.Println("EOF trying to write notification") - c.connected = false - return - } - - if err != nil { - log.Println("err writing to apns", err.Error()) - break - } - - cursor = cursor.Next() + c.cursor = cursor.Next() + c.buffer.Remove(cursor) } - } -} - -func readErrs(c *Conn) chan error { - errs := make(chan error) - - go func() { - p := make([]byte, 6, 6) - _, err := c.Read(p) - if err != nil { - errs <- err - return - } - - e := NewError(p) - errs <- &e - }() - return errs + cursor = cursor.Prev() + } } diff --git a/error.go b/error.go index 3ff4b86..3371bea 100644 --- a/error.go +++ b/error.go @@ -6,7 +6,7 @@ import ( "errors" ) -const ( +var ( ErrDisconnected = errors.New("disconnected from gateway") ) From 64b0fc2cc150373a44f7d476c33fb2d04010dbc8 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 16 Dec 2014 00:38:05 -0500 Subject: [PATCH 06/32] Update example --- example/example.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/example/example.go b/example/example.go index 1b670ac..2d9e5c0 100644 --- a/example/example.go +++ b/example/example.go @@ -10,10 +10,14 @@ import ( func main() { c, err := apns.NewClientWithFiles(apns.ProductionGateway, "apns.crt", "apns.key") if err != nil { - log.Fatal("Could not create client", err.Error()) + log.Fatal("Could not create client: ", err.Error()) } - i := 0 + if err := c.Connect(); err != nil { + log.Fatal("Could not create connect: ", err.Error()) + } + + i := 1 for { fmt.Print("Enter ' ': ") From 4a7edba7277c75606e22efe498f6cdd75529b1b9 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 23 Dec 2014 17:46:37 -0500 Subject: [PATCH 07/32] Add connection resource locking --- client.go | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 65ad2c4..d8c2441 100644 --- a/client.go +++ b/client.go @@ -106,6 +106,17 @@ func (c *Client) Connect() error { return nil } +func (c *Client) disconnect() error { + c.connm.Lock() + defer c.connm.Unlock() + + if c.Conn == nil { + return nil + } + + return c.Conn.Close() +} + func (c *Client) Send(n Notification) error { if !c.connected { return ErrDisconnected @@ -122,7 +133,14 @@ func (c *Client) Send(n Notification) error { // Add to list c.cursor = c.buffer.Add(n) - _, err = c.Conn.Write(b) + return c.send(b) +} + +func (c *Client) send(b []byte) error { + c.connm.Lock() + defer c.connm.Unlock() + + _, err := c.Conn.Write(b) if err == io.EOF { c.connected = false return err @@ -179,6 +197,8 @@ func (c *Client) readErrors() { e := NewError(p) cursor := c.buffer.Back() + c.disconnect() + for cursor != nil { // Get serialized notification s, _ := cursor.Value.(serialized) From e8836405420f2661d3c1f95453e90dee47563d38 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 23 Dec 2014 17:55:52 -0500 Subject: [PATCH 08/32] Update notification test and remove old client test file --- client_test.go | 384 ------------------------------------------- notification_test.go | 15 +- 2 files changed, 14 insertions(+), 385 deletions(-) delete mode 100644 client_test.go diff --git a/client_test.go b/client_test.go deleted file mode 100644 index c9dfd47..0000000 --- a/client_test.go +++ /dev/null @@ -1,384 +0,0 @@ -package apns_test - -import ( - "bytes" - "encoding/binary" - "io/ioutil" - "os" - "time" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/timehop/apns" -) - -var _ = Describe("Client", func() { - Describe(".NewConn", func() { - Context("bad cert/key pair", func() { - It("should error out", func() { - _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") - Expect(err).NotTo(BeNil()) - }) - }) - - Context("valid cert/key pair", func() { - It("should create a valid client", func() { - c, err := apns.NewClient(apns.ProductionGateway, DummyCert, DummyKey) - Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) - }) - }) - }) - - Describe(".NewConnWithFiles", func() { - Context("missing cert/key pair", func() { - It("should error out", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") - Expect(err).NotTo(BeNil()) - }) - }) - - Context("valid cert/key pair", func() { - var certFile, keyFile *os.File - - BeforeEach(func() { - certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) - certFile.Close() - - keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) - keyFile.Close() - }) - - AfterEach(func() { - if certFile != nil { - os.Remove(certFile.Name()) - } - - if keyFile != nil { - os.Remove(keyFile.Name()) - } - }) - - It("should create a valid client", func() { - c, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) - Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) - }) - }) - }) - - Describe("#Send", func() { - Context("simple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) - }) - }) - }) - - Context("simple write with buffer", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - for i := 0; i < 54; i++ { - Expect(c.Send(apns.Notification{})).To(BeNil()) - } - - close(mockDone) - close(d) - }) - }) - }) - - Context("multiple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - Expect(c.Send(apns.Notification{})).To(BeNil()) - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) - }) - }) - }) - - Context("bad push", func() { - n := apns.Notification{Identifier: 9, ID: "some_rando"} - nb, _ := n.ToBinary() - nbcb := make([]byte, len(nb)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(9)) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: nbcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(nb)) - }}, - - // Bad push results in a close - serverAction{action: writeAction, data: errPayload.Bytes()}, - serverAction{action: closeAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - go func() { - n := <-c.FailedNotifs - - Expect(n.Notif.Identifier).To(Equal(uint32(9))) - Expect(n.Notif.ID).To(Equal("some_rando")) - - close(mockDone) - close(d) - }() - - Expect(c.Send(n)).To(BeNil()) - }) - }) - }) - - Context("closed, reconnect", func() { - done := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n1b, _ := n1.ToBinary() - n1bcb := make([]byte, len(n1b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - as := [][]serverAction{ - []serverAction{ - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - done <- true - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - <-done - time.Sleep(5 * time.Millisecond) - - // Good - Expect(c.Send(n1)).To(BeNil()) - }) - }) - }) - - Context("good, close, good, requeue of last good", func() { - closed := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - closed <- true - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - // Good - Expect(c.Send(n1)).To(BeNil()) - - <-closed - time.Sleep(5 * time.Millisecond) - - // Good - Expect(c.Send(n2)).To(BeNil()) - }) - }) - }) - - Context("good, bad, good, requeue of last good", func() { - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - n3 := apns.Notification{Identifier: 3} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - n3b, _ := n3.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - n3bcb := make([]byte, len(n3b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Read bad notification - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - }}, - - // Read second good notification - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - }}, - - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - // Good - Expect(c.Send(n1)).To(BeNil()) - - // Bad - Expect(c.Send(n2)).To(BeNil()) - - // Good - Expect(c.Send(n3)).To(BeNil()) - }) - }) - }) - }) -}) diff --git a/notification_test.go b/notification_test.go index f0bf6cf..da77814 100644 --- a/notification_test.go +++ b/notification_test.go @@ -193,7 +193,20 @@ var _ = Describe("Notifications", func() { Describe("#ToBinary", func() { Context("invalid token format", func() { n := apns.NewNotification() - n.DeviceToken = "totally not a valid token" + n.DeviceToken = "totally not a valid token length" + + It("should return an error", func() { + _, err := n.ToBinary() + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(Equal(apns.ErrInvalidToken)) + }) + + // Expect(err.Error()).To(ContainSubstring("convert token to hex error")) + }) + + Context("non-convertable token", func() { + n := apns.NewNotification() + n.DeviceToken = "123456789012345678901234567890zz123456789012345678901234567890zz" It("should return an error", func() { _, err := n.ToBinary() From 032358131d23a2ebb4d5336bb9010bb91729daa7 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 27 Jan 2015 00:32:45 -0500 Subject: [PATCH 09/32] WIP Convert Conn into an interface and make feedback test --- apns_suite_test.go | 38 +++++ client.go | 4 +- conn.go | 66 +++++---- conn_test.go | 347 +++++++++------------------------------------ feedback.go | 12 +- feedback_test.go | 177 ++++++++++++----------- 6 files changed, 243 insertions(+), 401 deletions(-) diff --git a/apns_suite_test.go b/apns_suite_test.go index b0bcca4..858a5b9 100644 --- a/apns_suite_test.go +++ b/apns_suite_test.go @@ -5,8 +5,46 @@ import ( . "github.com/onsi/gomega" "testing" + "time" ) +type mockConn struct { + connect func() error + read func([]byte) (int, error) + readWithTimeout func([]byte, time.Time) (int, error) +} + +func (m *mockConn) Connect() error { + if m.connect != nil { + return m.connect() + } + + return nil +} + +func (m *mockConn) Read(b []byte) (int, error) { + if m.read != nil { + return m.read(b) + } + return 0, nil +} + +func (m *mockConn) Write([]byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Close() error { + return nil +} + +func (m *mockConn) ReadWithTimeout(b []byte, t time.Time) (int, error) { + if m.readWithTimeout != nil { + return m.readWithTimeout(b, t) + } + + return 0, nil +} + func TestApns(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Apns Suite") diff --git a/client.go b/client.go index d8c2441..e4fac1d 100644 --- a/client.go +++ b/client.go @@ -33,7 +33,7 @@ type serialized struct { } type Client struct { - Conn *Conn + Conn Conn FailedNotifs chan NotificationResult notifs chan serialized @@ -50,7 +50,7 @@ type Client struct { func newClientWithConn(gw string, conn Conn) Client { c := Client{ - Conn: &conn, + Conn: conn, FailedNotifs: make(chan NotificationResult), notifs: make(chan serialized), buffer: newBuffer(50), diff --git a/conn.go b/conn.go index d3aa712..1bb8bab 100644 --- a/conn.go +++ b/conn.go @@ -2,8 +2,10 @@ package apns import ( "crypto/tls" + "io" "net" "strings" + "time" ) const ( @@ -15,9 +17,16 @@ const ( ) // Conn is a wrapper for the actual TLS connections made to Apple -type Conn struct { - NetConn net.Conn - Conf *tls.Config +type Conn interface { + io.ReadWriteCloser + + Connect() error + ReadWithTimeout(p []byte, deadline time.Time) (int, error) +} + +type conn struct { + netConn net.Conn + tls *tls.Config gateway string connected bool @@ -25,19 +34,20 @@ type Conn struct { func NewConnWithCert(gw string, cert tls.Certificate) Conn { gatewayParts := strings.Split(gw, ":") - conf := tls.Config{ - Certificates: []tls.Certificate{cert}, - ServerName: gatewayParts[0], + tls := tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: gatewayParts[0], + InsecureSkipVerify: true, } - return Conn{gateway: gw, Conf: &conf} + return &conn{gateway: gw, tls: &tls} } // NewConnWithFiles creates a new Conn from certificate and key in the specified files func NewConn(gw string, crt string, key string) (Conn, error) { cert, err := tls.X509KeyPair([]byte(crt), []byte(key)) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil @@ -47,49 +57,51 @@ func NewConn(gw string, crt string, key string) (Conn, error) { func NewConnWithFiles(gw string, certFile string, keyFile string) (Conn, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil } // Connect actually creates the TLS connection -func (c *Conn) Connect() error { +func (c *conn) Connect() error { // Make sure the existing connection is closed - if c.NetConn != nil { - c.NetConn.Close() - } - - conn, err := net.Dial("tcp", c.gateway) - if err != nil { - return err + if c.netConn != nil { + c.netConn.Close() } - tlsConn := tls.Client(conn, c.Conf) - err = tlsConn.Handshake() + tlsConn, err := tls.Dial("tcp", c.gateway, c.tls) if err != nil { return err } - c.NetConn = tlsConn + c.netConn = tlsConn return nil } -func (c *Conn) Close() error { - if c.NetConn != nil { - return c.NetConn.Close() +func (c *conn) Close() error { + if c.netConn != nil { + return c.netConn.Close() } return nil } // Read reads data from the connection -func (c *Conn) Read(p []byte) (int, error) { - i, err := c.NetConn.Read(p) +func (c *conn) Read(p []byte) (int, error) { + i, err := c.netConn.Read(p) + return i, err +} + +// ReadWithTimeout reads data from the connection and returns an error +// after duration +func (c *conn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { + c.netConn.SetReadDeadline(deadline) + i, err := c.netConn.Read(p) return i, err } // Write writes data from the connection -func (c *Conn) Write(p []byte) (int, error) { - return c.NetConn.Write(p) +func (c *conn) Write(p []byte) (int, error) { + return c.netConn.Write(p) } diff --git a/conn_test.go b/conn_test.go index e910e6c..4bee29c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,232 +1,15 @@ package apns_test import ( - "bytes" - "crypto/tls" - "fmt" - "io" "io/ioutil" - "log" "net" "os" - "strings" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) -var DummyCert = `-----BEGIN CERTIFICATE----- -MIIC9TCCAd+gAwIBAgIQf3bEgFWUb+q6eK5ySkV/gjALBgkqhkiG9w0BAQUwEjEQ -MA4GA1UEChMHQWNtZSBDbzAeFw0xNDA2MzAwNDI5MDhaFw0xNTA2MzAwNDI5MDha -MBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK -AoIBAQDhAgWrrFZBtCfVEPg1tSIr9fuSUoeundb556IUr9uOmOHaYK7r3/I43acw -bVIfaenFxwUUf8YakQzTjOa5qSfK/Eylyw2ezBJtNUEqcHw0f+y66+jJbZa4clPa -tL6ezaMS/syXPpvNU8+16jdVdTJzqdBdSGAZMOCeumUWDNdlfBmHPVq1JMy0uGmO -XDoZK2Ir0/3LUfjk9R2wdm1VLrJAml7F0L0FhBHHXgHOSFM2ixjGflffaiuTCxhW -1z1NTo9XjWUQh2iM9Udf+xVnJLGLZ0EMFr2qihuK604Fp4SlNHEF+UWUn+j0PYo+ -LbzM9oKJcdVD0XI36vrn3rGPHO9vAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIAoDAT -BgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuCCWxv -Y2FsaG9zdDALBgkqhkiG9w0BAQUDggEBAGJ/3I4KKlbEwLAC5ut4ZZ9V8WF4sHkI -Lj7e4vx2pPi6hf9miV1ff01NrpfUna7flwL9yD7Ybl7jRRIB4rIcKk+U5djGsT3H -ScGkbIMKrr08drWw1g4JU6PBH7xTfzGxNRERrnmrbJV0jCo9Tt8i53IpPtp6Z2Q1 -8ydtPhU+Bpe2YoNr1w1fSV1JHXqjKV8RlGkCNSi4ozPOO8RbAYnBT3d9XSGoX//q -RGJUf3wC/rCxJkN63Moxuy3vxV2TmiqccHOrXJSJ8P/4PpPV/xuBk5k4HS1Nfmew -d9WHHn6bMJE9arVvWAiu9teCadVffuS2cl2cicN4XB6Ui0aDqhG2Exw= ------END CERTIFICATE-----` - -var DummyKey = `-----BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEA4QIFq6xWQbQn1RD4NbUiK/X7klKHrp3W+eeiFK/bjpjh2mCu -69/yON2nMG1SH2npxccFFH/GGpEM04zmuaknyvxMpcsNnswSbTVBKnB8NH/suuvo -yW2WuHJT2rS+ns2jEv7Mlz6bzVPPteo3VXUyc6nQXUhgGTDgnrplFgzXZXwZhz1a -tSTMtLhpjlw6GStiK9P9y1H45PUdsHZtVS6yQJpexdC9BYQRx14BzkhTNosYxn5X -32orkwsYVtc9TU6PV41lEIdojPVHX/sVZySxi2dBDBa9qoobiutOBaeEpTRxBflF -lJ/o9D2KPi28zPaCiXHVQ9FyN+r6596xjxzvbwIDAQABAoIBAFzW+cIA5MJNdFX8 -n32BlGzxHPEd7nAFHmuUwJKqkPwAZsg1NleK2qXOByr7IHRnvhZl7Nmtcu8JRHKR -Y63ddtbRTUrnQmJwL3YyEAZTzVvYILRrnGxoNFU8jw7hnvllPdEbow0QvzZ0S3Lz -BgvTxJJm0dt7fnNGcJftrsHvYHy1dptaR4hPv0xV5G7RPrbTl94llKfi745tp5Wd -xGpnjcBXoAnzCVRij1tHfSYubRJ2MJV0kzG3oVdRV2P/zWaout8BlhLCURv4sRUX -7FfCNa/z+G6AlROjCKJUP9YIUbxBEa/aP8YlSiyLRi1jFbMWcnKWQUdqS19m73Ap -a1LJFPECgYEA+Ve5DegcrWnUb2HsHD38HlmEg6S+/jg2P4TsuLZBtvO4/vzRx/qq -pwuuMm2CsvXr4nVmMEsMlSzYdsnaXIlWqyVDCOwIWR5VYT2GDWqQLaIXPlFaISzN -27tHd64KUtR1fMJUwQVK/MUORUbpYoAnSIil2SlYkWUhF024fNP8CxcCgYEA5wP4 -HLiqU2rqe7vSAF/8fHwPleTzuCfMCVZm0aegUzQQQtklZoVE/BBwEGHdXflq1veq -pHeC8bNR4BF6ZgeSWgbLVF3msquy47QeNElHA2muJd3qmNWz4LXo1Pxb8KXcnXri -QZ+r3Y8obWTFQYq7gGQGPLXGTV3bhLGIyrT4lWkCgYAgZ2MYSJL5gmhmNT6fCPsr -4oxTI2Ti2uFJ7fdppd3ybcgb8zU8HPpyjRUNXqf+o/EM1B78pbQz6skS3vau0fZe -dZA5p5sKIeQMqBc0xSWJmKgWpDHnX9A8/yCxj/+tdgjytrqW/x4YrW9GV4nbEDaK -uZ98EmB9PLxJMAOKzW3S7wKBgQDD4PCy4b3CR2iVC9dva/P5VXQdo+knX884p6M8 -58YgZofXNqnouN2aYRG0QlbiBMcbiRqOo6tK58JnnEpNUuQ8I4Cqg4hGPSHMwv/N -U8i70xLPltABUUpZIcVPOr92WBytBvHrtMiUb3tW7lf3T/vWTHmhZnvDQ+8LH0Ge -pz4T6QKBgQCoBJKOd781IQmT6i5hHSYJlsP6ymaaaQniJPVpnci/jf8+2QtponQY -scgnaBLBasLQ6GfKSRtcyidEi9wwxpVj0tw2p567jeNcIveD0TOYFf0RHEfrs+D4 -VdRgai/v2NbFZLDnzeGVuYypXu6R78isJfHtz/a0aEave8yB3CRiDw== ------END RSA PRIVATE KEY-----` - -// To be able to run in parallel -var mockPort = 50000 - -// Mock Addr -type mockAddr struct { -} - -func (m mockAddr) Network() string { - return "localhost:56789" -} - -func (m mockAddr) String() string { - return "localhost:56789" -} - -// Mock TLS connection -type mockTLSNetConn struct { - bb *bytes.Buffer - err error -} - -func (t mockTLSNetConn) Read(p []byte) (int, error) { - r := bytes.NewReader(t.bb.Bytes()) - return r.Read(p) -} - -func (t mockTLSNetConn) Write(p []byte) (int, error) { - return t.bb.Write(p) -} - -func (t mockTLSNetConn) Close() error { - return t.err -} - -func (m mockTLSNetConn) LocalAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) RemoteAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) SetDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetWriteDeadline(t time.Time) error { - return nil -} - -type serverAction struct { - action string - data []byte - cb func(s serverAction) -} - -const ( - readAction = "read" - writeAction = "write" - closeAction = "close" -) - -type mockTLSServer struct { - Port int - Server net.Listener - ConnectionActionGroups [][]serverAction -} - -func (m *mockTLSServer) portStr() string { - if m.Port == 0 { - mockPort = mockPort + 1 - m.Port = mockPort - } - - return fmt.Sprint(m.Port) -} - -func (m *mockTLSServer) Address() string { - return "localhost:" + m.portStr() -} - -func (m *mockTLSServer) start() { - cert, err := tls.X509KeyPair([]byte(DummyCert), []byte(DummyKey)) - if err != nil { - log.Panic(err) - } - - config := tls.Config{Certificates: []tls.Certificate{cert}, ClientAuth: tls.RequireAnyClientCert} - - m.Server, err = tls.Listen("tcp", "localhost:"+m.portStr(), &config) - go func() { - for i := 0; i < len(m.ConnectionActionGroups); i++ { - g := m.ConnectionActionGroups[i] - - // Wait for a connection. - conn, err := m.Server.Accept() - if err != nil { - if strings.Contains(err.Error(), "use of closed network connection") { - return - } else { - log.Fatal(err) - } - } - // Handle the connection in a new goroutine. - // The loop then returns to accepting, so that - // multiple connections may be served concurrently. - go func(c net.Conn) { - for j := 0; j < len(g); j++ { - a := g[j] - switch a.action { - case readAction: - c.Read(a.data) - case writeAction: - c.Write(a.data) - case closeAction: - c.Close() - - if a.cb != nil { - a.cb(a) - } - return - } - - if a.cb != nil { - a.cb(a) - } - } - }(conn) - } - - // No more connection action groups - }() -} - -func (m *mockTLSServer) stop() { - if m.Server != nil { - m.Server.Close() - } -} - -var withMockServer = func(as [][]serverAction, cb func(s *mockTLSServer)) { - d := make(chan interface{}) - withMockServerAsync(as, d, func(s *mockTLSServer) { - cb(s) - close(d) - }) -} - -var withMockServerAsync = func(as [][]serverAction, d chan interface{}, cb func(s *mockTLSServer)) { - s := &mockTLSServer{} - s.ConnectionActionGroups = as - - s.start() - - cb(s) - - <-d - s.stop() -} - // Tests var _ = Describe("Conn", func() { Describe(".NewConn", func() { @@ -239,7 +22,7 @@ var _ = Describe("Conn", func() { Context("valid key/cert pair", func() { It("should not return an error", func() { - _, err := apns.NewConn(apns.SandboxGateway, DummyCert, DummyKey) + _, err := apns.NewConn(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -259,11 +42,11 @@ var _ = Describe("Conn", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -295,65 +78,71 @@ var _ = Describe("Conn", func() { }) Context("server up", func() { - as := [][]serverAction{[]serverAction{serverAction{action: readAction, data: []byte{}}}} - Context("with untrusted certs", func() { It("should return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - err := conn.Connect() - Expect(err).NotTo(BeNil()) + s := tcptest.NewTLSServer(func(c net.Conn) {}) + defer s.Close() - close(d) - }) + conn, err := apns.NewConn(s.Addr, "not trusted", "not even a little") + Expect(err).NotTo(BeNil()) + + err = conn.Connect() + Expect(err).NotTo(BeNil()) + + close(d) }) }) Context("trusting the certs", func() { It("should not return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + s := tcptest.NewUnstartedServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) - err := conn.Connect() - Expect(err).To(BeNil()) + s.StartTLS() + defer s.Close() - close(d) - }) + conn, err := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = conn.Connect() + Expect(err).To(BeNil()) + + close(d) }) }) Context("with existing connection", func() { It("should not return an error", func(d Done) { - as = [][]serverAction{ - []serverAction{serverAction{action: readAction, data: []byte{}}}, - []serverAction{serverAction{action: readAction, data: []byte{}}}, - } + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) + defer s.Close() - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + conn.Connect() - err := conn.Connect() - Expect(err).To(BeNil()) + err := conn.Connect() + Expect(err).To(BeNil()) - close(d) - }) + close(d) }) }) }) }) Describe("#Read", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte("hello!"))} - - pp := make([]byte, 6) - bytes.NewReader(rwc.bb.Bytes()).Read(pp) + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte("hello!")) + }) + defer s.Close() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() It("should read out 'hello!'", func() { p := make([]byte, 6) @@ -364,47 +153,47 @@ var _ = Describe("Conn", func() { }) Describe("#Write", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} + It("should read out 'hello!'", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + + b := make([]byte, 6) + c.Read(b) + + Expect(string(b)).To(Equal("hello!")) + close(d) + }) + defer s.Close() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - It("should write out 'world!'", func() { - conn.Write([]byte("world!")) - Expect(rwc.bb.String()).To(Equal("world!")) + conn.Write([]byte("hello!")) }) }) Describe("#Close", func() { Context("with connection", func() { Context("no error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} - - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc - It("should return no error", func() { - Expect(rwc.Close()).To(BeNil()) - }) - }) - - Context("with error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} - - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) + defer s.Close() - rwc.err = io.EOF - It("should return that error", func() { - Expect(rwc.Close()).To(Equal(io.EOF)) + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() + Expect(conn.Close()).To(BeNil()) }) }) }) Context("without connection", func() { - c, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) It("should not return an error", func() { - Expect(c.Close()).To(BeNil()) + conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(conn.Close()).To(BeNil()) }) }) }) diff --git a/feedback.go b/feedback.go index 488bf1b..a0093b1 100644 --- a/feedback.go +++ b/feedback.go @@ -9,7 +9,7 @@ import ( ) type Feedback struct { - Conn *Conn + Conn Conn } type FeedbackTuple struct { @@ -40,7 +40,7 @@ func feedbackTupleFromBytes(b []byte) FeedbackTuple { func NewFeedbackWithCert(gw string, cert tls.Certificate) Feedback { conn := NewConnWithCert(gw, cert) - return Feedback{Conn: &conn} + return Feedback{Conn: conn} } func NewFeedback(gw string, cert string, key string) (Feedback, error) { @@ -49,7 +49,7 @@ func NewFeedback(gw string, cert string, key string) (Feedback, error) { return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, error) { @@ -58,7 +58,7 @@ func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } // Receive returns a read only channel for APNs feedback. The returned channel @@ -80,9 +80,7 @@ func (f Feedback) receive(fc chan FeedbackTuple) { for { b := make([]byte, 38) - f.Conn.NetConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - - _, err := f.Conn.Read(b) + _, err := f.Conn.ReadWithTimeout(b, time.Now().Add(100*time.Millisecond)) if err != nil { close(fc) return diff --git a/feedback_test.go b/feedback_test.go index 29978b4..8cd909b 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -4,12 +4,16 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "fmt" + "io" "io/ioutil" + "net" "os" "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) var _ = Describe("Feedback", func() { @@ -23,7 +27,7 @@ var _ = Describe("Feedback", func() { Context("valid cert/key pair", func() { It("should create a valid client", func() { - _, err := apns.NewFeedback(apns.ProductionGateway, DummyCert, DummyKey) + _, err := apns.NewFeedback(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -42,11 +46,11 @@ var _ = Describe("Feedback", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -70,11 +74,13 @@ var _ = Describe("Feedback", func() { Describe("#Receive", func() { Context("could not connect", func() { It("should not receive anything", func() { - s := &mockTLSServer{} - - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true + m := mockConn{ + connect: func() error { + return io.EOF + }, + } + f := apns.Feedback{Conn: &m} c := f.Receive() r := 0 @@ -87,89 +93,88 @@ var _ = Describe("Feedback", func() { }) Context("times out", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - It("should not receive anything", func() { - c := f.Receive() - - r := 0 - for _ = range c { - r += 1 - } - - Expect(r).To(Equal(0)) - }) + It("should not receive anything", func() { + m := mockConn{ + readWithTimeout: func(b []byte, t time.Time) (int, error) { + return 0, net.UnknownNetworkError("") + }, + } + + f := apns.Feedback{Conn: &m} + c := f.Receive() + + r := 0 + for _ = range c { + r += 1 + } + + Expect(r).To(Equal(0)) }) }) + }) - Context("with feedback", func() { - f1 := bytes.NewBuffer([]byte{}) - f2 := bytes.NewBuffer([]byte{}) - f3 := bytes.NewBuffer([]byte{}) - - // The final token strings - t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" - t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" - t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" - - bt1, _ := hex.DecodeString(t1) - bt2, _ := hex.DecodeString(t2) - bt3, _ := hex.DecodeString(t3) - - binary.Write(f1, binary.BigEndian, uint32(1404358249)) - binary.Write(f1, binary.BigEndian, uint16(len(bt1))) - binary.Write(f1, binary.BigEndian, bt1) - - binary.Write(f2, binary.BigEndian, uint32(1404352249)) - binary.Write(f2, binary.BigEndian, uint16(len(bt2))) - binary.Write(f2, binary.BigEndian, bt2) - - binary.Write(f3, binary.BigEndian, uint32(1394352249)) - binary.Write(f3, binary.BigEndian, uint16(len(bt3))) - binary.Write(f3, binary.BigEndian, bt3) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: writeAction, data: f1.Bytes()}, - serverAction{action: writeAction, data: f2.Bytes()}, - serverAction{action: writeAction, data: f3.Bytes()}, - }, - } - - It("should receive feedback", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - c := f.Receive() - - r1 := <-c - Expect(r1.Timestamp).To(Equal(time.Unix(1404358249, 0))) - Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) - Expect(r1.DeviceToken).To(Equal(t1)) - - r2 := <-c - Expect(r2.Timestamp).To(Equal(time.Unix(1404352249, 0))) - Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) - Expect(r2.DeviceToken).To(Equal(t2)) - - r3 := <-c - Expect(r3.Timestamp).To(Equal(time.Unix(1394352249, 0))) - Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) - Expect(r3.DeviceToken).To(Equal(t3)) - - <-c - close(d) - }) + Context("with feedback", func() { + f1 := bytes.NewBuffer([]byte{}) + f2 := bytes.NewBuffer([]byte{}) + f3 := bytes.NewBuffer([]byte{}) + + // The final token strings + t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" + t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" + t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" + + bt1, _ := hex.DecodeString(t1) + bt2, _ := hex.DecodeString(t2) + bt3, _ := hex.DecodeString(t3) + + binary.Write(f1, binary.BigEndian, uint32(1404358249)) + binary.Write(f1, binary.BigEndian, uint16(len(bt1))) + binary.Write(f1, binary.BigEndian, bt1) + + binary.Write(f2, binary.BigEndian, uint32(1404352249)) + binary.Write(f2, binary.BigEndian, uint16(len(bt2))) + binary.Write(f2, binary.BigEndian, bt2) + + binary.Write(f3, binary.BigEndian, uint32(1394352249)) + fmt.Println("f3 bytes", f3) + + binary.Write(f3, binary.BigEndian, uint16(len(bt3))) + binary.Write(f3, binary.BigEndian, bt3) + + It("should receive feedback", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write(f1.Bytes()) + c.Write(f2.Bytes()) + c.Write(f3.Bytes()) + + // TODO(bw) figure out why we need this + c.Write([]byte{0}) + c.Close() }) + defer s.Close() + + f, err := apns.NewFeedback(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c := f.Receive() + + r1 := <-c + Expect(r1.Timestamp.Unix()).To(Equal(int64(1404358249))) + Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) + Expect(r1.DeviceToken).To(Equal(t1)) + + r2 := <-c + Expect(r2.Timestamp.Unix()).To(Equal(int64(1404352249))) + Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) + Expect(r2.DeviceToken).To(Equal(t2)) + + r3 := <-c + Expect(r3.Timestamp.Unix()).To(Equal(int64(1394352249))) + Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) + Expect(r3.DeviceToken).To(Equal(t3)) + + <-c + close(d) }) }) }) From e2b7609fda879bbfe18f5b918c2149f9a4102132 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 28 Jan 2015 19:24:44 -0500 Subject: [PATCH 10/32] Lengthen timeouts for real tcp conn negotiation stuff --- conn.go | 6 ++-- conn_test.go | 85 ++++++++++++++++++++++++------------------------ feedback_test.go | 8 ++--- 3 files changed, 47 insertions(+), 52 deletions(-) diff --git a/conn.go b/conn.go index 1bb8bab..b14f8b8 100644 --- a/conn.go +++ b/conn.go @@ -89,16 +89,14 @@ func (c *conn) Close() error { // Read reads data from the connection func (c *conn) Read(p []byte) (int, error) { - i, err := c.netConn.Read(p) - return i, err + return c.netConn.Read(p) } // ReadWithTimeout reads data from the connection and returns an error // after duration func (c *conn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { c.netConn.SetReadDeadline(deadline) - i, err := c.netConn.Read(p) - return i, err + return c.netConn.Read(p) } // Write writes data from the connection diff --git a/conn_test.go b/conn_test.go index 4bee29c..44ebf23 100644 --- a/conn_test.go +++ b/conn_test.go @@ -129,72 +129,71 @@ var _ = Describe("Conn", func() { Expect(err).To(BeNil()) close(d) - }) + }, 10) }) }) }) Describe("#Read", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - defer c.Close() - c.Write([]byte("hello!")) - }) - defer s.Close() + It("should read out 'hello!'", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte("hello!")) + }) + defer s.Close() - conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - It("should read out 'hello!'", func() { p := make([]byte, 6) conn.Read(p) Expect(p).To(Equal([]byte("hello!"))) }) }) +}) - Describe("#Write", func() { - It("should read out 'hello!'", func(d Done) { - s := tcptest.NewTLSServer(func(c net.Conn) { - defer c.Close() - c.Write([]byte{}) // Connect - - b := make([]byte, 6) - c.Read(b) - - Expect(string(b)).To(Equal("hello!")) - close(d) - }) - defer s.Close() +var _ = Describe("#Write", func() { + It("should read out 'hello!'", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect - conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + b := make([]byte, 6) + c.Read(b) - conn.Write([]byte("hello!")) + Expect(string(b)).To(Equal("hello!")) + close(d) }) - }) - Describe("#Close", func() { - Context("with connection", func() { - Context("no error", func() { - It("should return no error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - defer c.Close() - c.Write([]byte{}) // Connect - }) - defer s.Close() + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() - Expect(conn.Close()).To(BeNil()) + conn.Write([]byte("hello!")) + }, 10) +}) + +var _ = Describe("#Close", func() { + Context("with connection", func() { + Context("no error", func() { + It("should return no error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect }) - }) - }) + defer s.Close() - Context("without connection", func() { - It("should not return an error", func() { - conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() Expect(conn.Close()).To(BeNil()) }) }) }) + + Context("without connection", func() { + It("should not return an error", func() { + conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(conn.Close()).To(BeNil()) + }) + }) }) diff --git a/feedback_test.go b/feedback_test.go index 8cd909b..96dadb4 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "encoding/hex" - "fmt" "io" "io/ioutil" "net" @@ -136,8 +135,6 @@ var _ = Describe("Feedback", func() { binary.Write(f2, binary.BigEndian, bt2) binary.Write(f3, binary.BigEndian, uint32(1394352249)) - fmt.Println("f3 bytes", f3) - binary.Write(f3, binary.BigEndian, uint16(len(bt3))) binary.Write(f3, binary.BigEndian, bt3) @@ -147,8 +144,9 @@ var _ = Describe("Feedback", func() { c.Write(f2.Bytes()) c.Write(f3.Bytes()) - // TODO(bw) figure out why we need this + // TODO(bw) this doesn't seem right c.Write([]byte{0}) + c.Close() }) defer s.Close() @@ -175,6 +173,6 @@ var _ = Describe("Feedback", func() { <-c close(d) - }) + }, 10) }) }) From abc584f4df7b7299101431c0127bff5ca9c15b9f Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 28 Jan 2015 19:27:47 -0500 Subject: [PATCH 11/32] Add race detection --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 18c443a..687544f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ before_script: - go get github.com/onsi/gomega - go get code.google.com/p/go.tools/cmd/cover - go install github.com/onsi/ginkgo/ginkgo -script: ginkgo -r --skipMeasurements --cover --trace +script: ginkgo -r --skipMeasurements --cover --trace --race env: global: - PATH=$HOME/gopath/bin:$PATH From 1a7760cdb6ed8592dd0bf94bbc32d1a4596cb421 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 28 Jan 2015 19:32:04 -0500 Subject: [PATCH 12/32] Extend timeout --- conn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_test.go b/conn_test.go index 44ebf23..d388a6d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -110,7 +110,7 @@ var _ = Describe("Conn", func() { Expect(err).To(BeNil()) close(d) - }) + }, 10) }) Context("with existing connection", func() { From 2c85a0a99d2e26e015d5414c892cfa18edd3e2b4 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 12 Feb 2015 22:13:31 -0500 Subject: [PATCH 13/32] Start filling out client test --- client.go | 9 ++++ client_test.go | 125 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 client_test.go diff --git a/client.go b/client.go index e4fac1d..c52f623 100644 --- a/client.go +++ b/client.go @@ -93,6 +93,8 @@ func (c *Client) Connect() error { return err } + c.setConnected(true) + // On connect, requeue any notifications that were // sent after the error & disconnect. // http://redth.codes/the-problem-with-apples-push-notification-ser/ @@ -218,3 +220,10 @@ func (c *Client) readErrors() { cursor = cursor.Prev() } } + +func (c *Client) setConnected(connected bool) { + c.connm.Lock() + defer c.connm.Unlock() + + c.connected = true +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..618dd2b --- /dev/null +++ b/client_test.go @@ -0,0 +1,125 @@ +package apns_test + +import ( + "io/ioutil" + "net" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/timehop/apns" + "github.com/timehop/tcptest" +) + +var _ = Describe("Client", func() { + Describe(".NewClient", func() { + Context("bad cert/key pair", func() { + It("should error out", func() { + _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") + Expect(err).NotTo(BeNil()) + }) + }) + + Context("valid cert/key pair", func() { + It("should create a valid client", func() { + _, err := apns.NewClient(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + }) + }) + }) + + Describe(".NewClientWithFiles", func() { + Context("missing cert/key pair", func() { + It("should error out", func() { + _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") + Expect(err).NotTo(BeNil()) + }) + }) + + Context("valid cert/key pair", func() { + var certFile, keyFile *os.File + + BeforeEach(func() { + certFile, _ = ioutil.TempFile("", "cert.pem") + certFile.Write([]byte(tcptest.LocalhostCert)) + certFile.Close() + + keyFile, _ = ioutil.TempFile("", "key.pem") + keyFile.Write([]byte(tcptest.LocalhostKey)) + keyFile.Close() + }) + + AfterEach(func() { + if certFile != nil { + os.Remove(certFile.Name()) + } + + if keyFile != nil { + os.Remove(keyFile.Name()) + } + }) + + It("should create a valid client", func() { + _, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) + Expect(err).To(BeNil()) + }) + }) + }) + + Describe("Connect", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + }) + }) + + Describe("Send", func() { + Context("valid push", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + }) + }) + + Context("invalid notification", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{DeviceToken: "lol"}) + Expect(err).NotTo(BeNil()) + }) + }) + }) +}) From 9e1b5bb2e51be9716d72c6188617a6dad7c34c57 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 12 Feb 2015 22:18:52 -0500 Subject: [PATCH 14/32] Remove extraneous function --- client.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/client.go b/client.go index c52f623..8be71c1 100644 --- a/client.go +++ b/client.go @@ -93,7 +93,7 @@ func (c *Client) Connect() error { return err } - c.setConnected(true) + c.connected = true // On connect, requeue any notifications that were // sent after the error & disconnect. @@ -220,10 +220,3 @@ func (c *Client) readErrors() { cursor = cursor.Prev() } } - -func (c *Client) setConnected(connected bool) { - c.connm.Lock() - defer c.connm.Unlock() - - c.connected = true -} From 4fa74c969def1fb7f4082214e537c7105d6fc336 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 12 Feb 2015 22:46:05 -0500 Subject: [PATCH 15/32] Adding more client tests --- client.go | 13 +++++------ client_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 8be71c1..8f1c157 100644 --- a/client.go +++ b/client.go @@ -36,8 +36,6 @@ type Client struct { Conn Conn FailedNotifs chan NotificationResult - notifs chan serialized - buffer *buffer cursor *list.Element @@ -52,7 +50,6 @@ func newClientWithConn(gw string, conn Conn) Client { c := Client{ Conn: conn, FailedNotifs: make(chan NotificationResult), - notifs: make(chan serialized), buffer: newBuffer(50), cursor: nil, id: 0, @@ -197,19 +194,19 @@ func (c *Client) readErrors() { } e := NewError(p) - cursor := c.buffer.Back() - c.disconnect() + cursor := c.buffer.Back() + for cursor != nil { // Get serialized notification - s, _ := cursor.Value.(serialized) + n, _ := cursor.Value.(Notification) // If the notification, move cursor after the trouble notification - if s.id == e.Identifier { + if n.Identifier == e.Identifier { // Try to write - skip if no one is reading on the other side select { - case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: e}: + case c.FailedNotifs <- NotificationResult{Notif: n, Err: e}: default: } diff --git a/client_test.go b/client_test.go index 618dd2b..5a6b7e5 100644 --- a/client_test.go +++ b/client_test.go @@ -83,6 +83,62 @@ var _ = Describe("Client", func() { }) }) + Describe("Reading Errors", func() { + Context("send a notification and get an error", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte("123456")) + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + nr := <-c.FailedNotifs + Expect(nr.Err).NotTo(BeNil()) + Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) + }) + }) + + Context("send a multiple notifications and get an error", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte("123456")) + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 159059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 259059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + nr := <-c.FailedNotifs + Expect(nr.Err).NotTo(BeNil()) + Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) + }) + }) + }) + Describe("Send", func() { Context("valid push", func() { It("should not return an error", func() { @@ -104,7 +160,7 @@ var _ = Describe("Client", func() { }) Context("invalid notification", func() { - It("should not return an error", func() { + It("should return an error", func() { s := tcptest.NewTLSServer(func(c net.Conn) { c.Write([]byte{0}) c.Close() From df2b03d89d118abdc0afa0c8cee454f019fe0dfd Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 17 Feb 2015 00:01:48 -0500 Subject: [PATCH 16/32] Introduce session concept --- client.go | 213 ++++++++++-------------------------------------- client_test.go | 175 +++++++++++++++------------------------ session.go | 216 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 326 insertions(+), 278 deletions(-) create mode 100644 session.go diff --git a/client.go b/client.go index 8f1c157..b9159bb 100644 --- a/client.go +++ b/client.go @@ -1,70 +1,33 @@ package apns import ( - "container/list" "crypto/tls" - "io" "sync" + "time" ) -type buffer struct { - size int - *list.List -} - -func newBuffer(size int) *buffer { - return &buffer{size, list.New()} -} - -func (b *buffer) Add(v interface{}) *list.Element { - e := b.PushBack(v) - - if b.Len() > b.size { - b.Remove(b.Front()) - } - - return e -} - -type serialized struct { - id uint32 - b []byte - n *Notification -} - type Client struct { - Conn Conn - FailedNotifs chan NotificationResult + conn Conn - buffer *buffer - cursor *list.Element - - id uint32 - idm sync.Mutex - - connected bool - connm sync.Mutex + sess Session + sessm sync.Mutex } -func newClientWithConn(gw string, conn Conn) Client { - c := Client{ - Conn: conn, - FailedNotifs: make(chan NotificationResult), - buffer: newBuffer(50), - cursor: nil, - id: 0, - idm: sync.Mutex{}, - connected: false, - connm: sync.Mutex{}, +func newClientWithConn(conn Conn) (Client, error) { + c := Client{conn: conn} + + sess := newSession(conn) + err := sess.Connect() + if err != nil { + return c, err } - return c + return Client{conn, sess, sync.Mutex{}}, nil } -func NewClientWithCert(gw string, cert tls.Certificate) Client { +func NewClientWithCert(gw string, cert tls.Certificate) (Client, error) { conn := NewConnWithCert(gw, cert) - - return newClientWithConn(gw, conn) + return newClientWithConn(conn) } func NewClient(gw string, cert string, key string) (Client, error) { @@ -73,7 +36,7 @@ func NewClient(gw string, cert string, key string) (Client, error) { return Client{}, err } - return newClientWithConn(gw, conn), nil + return newClientWithConn(conn) } func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, error) { @@ -82,138 +45,48 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err return Client{}, err } - return newClientWithConn(gw, conn), nil -} - -func (c *Client) Connect() error { - if err := c.Conn.Connect(); err != nil { - return err - } - - c.connected = true - - // On connect, requeue any notifications that were - // sent after the error & disconnect. - // http://redth.codes/the-problem-with-apples-push-notification-ser/ - if err := c.requeue(); err != nil { - return err - } - - // Kick off asynchronous error reading - go c.readErrors() - - return nil -} - -func (c *Client) disconnect() error { - c.connm.Lock() - defer c.connm.Unlock() - - if c.Conn == nil { - return nil - } - - return c.Conn.Close() + return newClientWithConn(conn) } func (c *Client) Send(n Notification) error { - if !c.connected { - return ErrDisconnected + if c.sess.Disconnected() { + c.reconnectAndRequeue() } - // Set identifier if not specified - n.Identifier = c.determineIdentifier(n.Identifier) - - b, err := n.ToBinary() - if err != nil { - return err - } - - // Add to list - c.cursor = c.buffer.Add(n) - - return c.send(b) + return c.sess.Send(n) } -func (c *Client) send(b []byte) error { - c.connm.Lock() - defer c.connm.Unlock() - - _, err := c.Conn.Write(b) - if err == io.EOF { - c.connected = false - return err - } +func (c *Client) reconnectAndRequeue() { + c.sessm.Lock() + defer c.sessm.Unlock() - if err != nil { - return err - } + // Pull off undelivered notifications + notifs := c.sess.RequeueableNotifications() - c.cursor = c.cursor.Next() - return nil -} + // Reconnect + c.sess = nil -func (c *Client) determineIdentifier(n uint32) uint32 { - c.idm.Lock() - defer c.idm.Unlock() - - // If the id passed in is 0, that means it wasn't - // set so get the next ID. Otherwise, set it to that - // identifier. - if n == 0 { - c.id++ - } else { - c.id = n - } - - return c.id -} + for c.sess == nil { + sess := newSession(c.conn) -func (c *Client) requeue() error { - // If `cursor` is not nil, this means there are notifications that - // need to be delivered (or redelivered) - for ; c.cursor != nil; c.cursor = c.cursor.Next() { - if s, ok := c.cursor.Value.(serialized); ok { - if err := c.Send(*s.n); err != nil { - return err - } + err := sess.Connect() + if err != nil { + // TODO retry policy + // TODO connect error channel + // Keep trying to connect + time.Sleep(1 * time.Second) + continue } - } - - return nil -} - -func (c *Client) readErrors() { - p := make([]byte, 6, 6) - _, err := c.Conn.Read(p) - // TODO(bw) not sure what to do here. It's unclear what errors - // come out of this and how we handle it. - if err != nil { - return + c.sess = sess } - e := NewError(p) - c.disconnect() - - cursor := c.buffer.Back() - - for cursor != nil { - // Get serialized notification - n, _ := cursor.Value.(Notification) - - // If the notification, move cursor after the trouble notification - if n.Identifier == e.Identifier { - // Try to write - skip if no one is reading on the other side - select { - case c.FailedNotifs <- NotificationResult{Notif: n, Err: e}: - default: - } - - c.cursor = cursor.Next() - c.buffer.Remove(cursor) - } - - cursor = cursor.Prev() + for _, n := range notifs { + // TODO handle error from sending + c.sess.Send(n) } } + +var newSession = func(c Conn) Session { + return NewSession(c) +} diff --git a/client_test.go b/client_test.go index 5a6b7e5..5fccc45 100644 --- a/client_test.go +++ b/client_test.go @@ -1,38 +1,80 @@ -package apns_test +package apns import ( + "errors" "io/ioutil" - "net" "os" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "github.com/timehop/apns" "github.com/timehop/tcptest" ) +type mockSession struct { + sendErr error +} + +func (m mockSession) Send(n Notification) error { + return m.sendErr +} + +func (m mockSession) Connect() error { + return nil +} + +func (m mockSession) RequeueableNotifications() []Notification { + return []Notification{} +} + +func (m mockSession) Disconnect() { +} + +func (m mockSession) Disconnected() bool { + return false +} + +type badConnMockSession struct { + mockSession +} + +func (_ badConnMockSession) Connect() error { + return errors.New("whatev") +} + var _ = Describe("Client", func() { + BeforeEach(func() { + newSession = func(_ Conn) Session { return mockSession{} } + }) + Describe(".NewClient", func() { Context("bad cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClient(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) Context("valid cert/key pair", func() { It("should create a valid client", func() { - _, err := apns.NewClient(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) + + Context("bad connection", func() { + It("should error out", func() { + newSession = func(_ Conn) Session { return badConnMockSession{} } + + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).NotTo(BeNil()) + }) + }) }) Describe(".NewClientWithFiles", func() { Context("missing cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClientWithFiles(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) @@ -61,121 +103,38 @@ var _ = Describe("Client", func() { }) It("should create a valid client", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) + _, err := NewClientWithFiles(ProductionGateway, certFile.Name(), keyFile.Name()) Expect(err).To(BeNil()) }) }) }) - Describe("Connect", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte{0}) - c.Close() - }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - }) - }) - - Describe("Reading Errors", func() { - Context("send a notification and get an error", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte("123456")) - c.Write([]byte{0}) - c.Close() + Describe("Send", func() { + Context("connected", func() { + Context("valid push", func() { + It("should not return an error", func() { + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) - - nr := <-c.FailedNotifs - Expect(nr.Err).NotTo(BeNil()) - Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) }) - }) - - Context("send a multiple notifications and get an error", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte("123456")) - c.Write([]byte{0}) - c.Close() - }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) + Context("invalid notification", func() { + It("should return an error", func() { + newSession = func(_ Conn) Session { return mockSession{sendErr: errors.New("")} } - err = c.Send(apns.Notification{Identifier: 159059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{Identifier: 259059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) - - nr := <-c.FailedNotifs - Expect(nr.Err).NotTo(BeNil()) - Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) - }) - }) - }) + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - Describe("Send", func() { - Context("valid push", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte{0}) - c.Close() + err = c.Send(Notification{DeviceToken: "lol"}) + Expect(err).NotTo(BeNil()) }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) }) }) - Context("invalid notification", func() { - It("should return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte{0}) - c.Close() - }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{DeviceToken: "lol"}) - Expect(err).NotTo(BeNil()) - }) + Context("disconnected", func() { }) }) }) diff --git a/session.go b/session.go new file mode 100644 index 0000000..ec9fd67 --- /dev/null +++ b/session.go @@ -0,0 +1,216 @@ +package apns + +import ( + "container/list" + "errors" + "io" + "sync" +) + +type SessionError struct { + Notification Notification + Err Error +} + +func (s SessionError) Error() string { + return s.Err.Error() +} + +type Session interface { + Send(n Notification) error + Connect() error + RequeueableNotifications() []Notification + Disconnect() + Disconnected() bool +} + +type buffer struct { + size int + m sync.Mutex + *list.List +} + +func newBuffer(size int) *buffer { + return &buffer{size, sync.Mutex{}, list.New()} +} + +func (b *buffer) Add(v interface{}) *list.Element { + b.m.Lock() + defer b.m.Unlock() + + e := b.PushBack(v) + + if b.Len() > b.size { + b.Remove(b.Front()) + } + + return e +} + +type sessionState int + +const ( + sessionStateNew sessionState = iota + sessionStateConnected sessionState = iota + sessionStateDisconnected sessionState = iota +) + +type session struct { + b *buffer + + conn Conn + connm sync.Mutex + + st sessionState + stm sync.Mutex + + id uint32 + idm sync.Mutex + + err SessionError +} + +func NewSession(conn Conn) Session { + return &session{ + st: sessionStateNew, + stm: sync.Mutex{}, + conn: conn, + connm: sync.Mutex{}, + idm: sync.Mutex{}, + b: newBuffer(50), + } +} + +func (s *session) Connect() error { + if s.st != sessionStateNew { + return errors.New("can't connect unless the session is new") + } + + go s.readErrors() + return nil +} + +func (s *session) Disconnected() bool { + return s.st == sessionStateDisconnected +} + +func (s *session) Send(n Notification) error { + // If disconnected, error out + if s.st != sessionStateConnected { + return errors.New("not connected") + } + + // Set identifier if not specified + n.Identifier = s.determineIdentifier(n.Identifier) + + // Serialize + b, err := n.ToBinary() + if err != nil { + return err + } + + // Add to buffer + s.b.Add(n) + + // Send synchronously + return s.send(b) +} + +func (s *session) send(b []byte) error { + s.connm.Lock() + defer s.connm.Unlock() + + _, err := s.conn.Write(b) + if err == io.EOF { + s.Disconnect() + return err + } + + if err != nil { + return err + } + + return nil +} + +func (s *session) Disconnect() { + // Disconnect + s.transitionState(sessionStateDisconnected) +} + +func (s *session) RequeueableNotifications() []Notification { + notifs := []Notification{} + + // If still connected, return nothing + if s.st != sessionStateDisconnected { + return notifs + } + + // Walk back to last known good notification and return the slice + var e *list.Element + for e = s.b.Front(); e != nil; e = e.Next() { + if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Notification.Identifier { + break + } + } + + // Start right after errored ID and get the rest of the list + for e = e.Next(); e != nil; e = e.Next() { + n, ok := e.Value.(Notification) + if !ok { + continue + } + + notifs = append(notifs, n) + } + + return notifs +} + +func (s *session) transitionState(st sessionState) { + s.stm.Lock() + defer s.stm.Unlock() + + s.st = st +} + +func (s *session) determineIdentifier(n uint32) uint32 { + s.idm.Lock() + defer s.idm.Unlock() + + // If the id passed in is 0, that means it wasn't + // set so get the next ID. Otherwise, set it to that + // identifier. + if n == 0 { + s.id++ + } else { + s.id = n + } + + return s.id +} + +func (s *session) readErrors() { + p := make([]byte, 6, 6) + + _, err := s.conn.Read(p) + // TODO(bw) not sure what to do here. It's unclear what errors + // come out of this and how we handle it. + if err != nil { + return + } + + s.Disconnect() + + e := NewError(p) + + for cursor := s.b.Back(); cursor != nil; cursor = cursor.Prev() { + // Get serialized notification + n, _ := cursor.Value.(Notification) + + // If the notification, move cursor after the trouble notification + if n.Identifier == e.Identifier { + s.err = SessionError{n, e} + } + } +} From 06384bf4c305db3b984fdf9dcdf0994645d4249d Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 17 Feb 2015 00:05:47 -0500 Subject: [PATCH 17/32] Fix example --- example/example.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/example/example.go b/example/example.go index 2d9e5c0..637d0af 100644 --- a/example/example.go +++ b/example/example.go @@ -13,10 +13,6 @@ func main() { log.Fatal("Could not create client: ", err.Error()) } - if err := c.Connect(); err != nil { - log.Fatal("Could not create connect: ", err.Error()) - } - i := 1 for { fmt.Print("Enter ' ': ") From d807d6b2557fc25ca7076bc6ce833cac0892fab8 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 17 Feb 2015 18:18:05 -0500 Subject: [PATCH 18/32] Add more tests for client --- client_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 77 insertions(+), 12 deletions(-) diff --git a/client_test.go b/client_test.go index 5fccc45..7dbee14 100644 --- a/client_test.go +++ b/client_test.go @@ -11,30 +11,41 @@ import ( ) type mockSession struct { - sendErr error + sendCB func(n Notification) error + requeueNotifs []Notification + disconnectedState bool } -func (m mockSession) Send(n Notification) error { - return m.sendErr +func (m *mockSession) Send(n Notification) error { + if m.sendCB == nil { + return nil + } + + return m.sendCB(n) } -func (m mockSession) Connect() error { +func (m *mockSession) Connect() error { return nil } -func (m mockSession) RequeueableNotifications() []Notification { - return []Notification{} +func (m *mockSession) RequeueableNotifications() []Notification { + if len(m.requeueNotifs) == 0 { + return []Notification{} + } + + return m.requeueNotifs } -func (m mockSession) Disconnect() { +func (m *mockSession) Disconnect() { + m.disconnectedState = true } -func (m mockSession) Disconnected() bool { - return false +func (m *mockSession) Disconnected() bool { + return m.disconnectedState } type badConnMockSession struct { - mockSession + *mockSession } func (_ badConnMockSession) Connect() error { @@ -43,7 +54,7 @@ func (_ badConnMockSession) Connect() error { var _ = Describe("Client", func() { BeforeEach(func() { - newSession = func(_ Conn) Session { return mockSession{} } + newSession = func(_ Conn) Session { return &mockSession{} } }) Describe(".NewClient", func() { @@ -123,7 +134,13 @@ var _ = Describe("Client", func() { Context("invalid notification", func() { It("should return an error", func() { - newSession = func(_ Conn) Session { return mockSession{sendErr: errors.New("")} } + newSession = func(_ Conn) Session { + return &mockSession{ + sendCB: func(_ Notification) error { + return errors.New("") + }, + } + } c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) @@ -135,6 +152,54 @@ var _ = Describe("Client", func() { }) Context("disconnected", func() { + It("should reconnect", func() { + newSessCount := 0 + newSession = func(_ Conn) Session { + newSessCount += 1 + return &mockSession{} + } + + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c.sess.Disconnect() + + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + Expect(newSessCount).To(Equal(2)) + }) + }) + + It("should reconnect and requeue", func() { + newSessCount := 0 + sendCount := 0 + + newSession = func(_ Conn) Session { + newSessCount += 1 + return &mockSession{ + requeueNotifs: []Notification{ + Notification{}, + Notification{}, + Notification{}, + }, + sendCB: func(_ Notification) error { + sendCount += 1 + return nil + }, + } + } + + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c.sess.Disconnect() + + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + Expect(newSessCount).To(Equal(2)) + Expect(sendCount).To(Equal(4)) }) }) }) From 12a1dfa8729f90ca20f81e25df55eb230773c65d Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 25 Feb 2015 17:13:43 -0500 Subject: [PATCH 19/32] Beginnings of a session test --- session_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 session_test.go diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..67a598f --- /dev/null +++ b/session_test.go @@ -0,0 +1,94 @@ +package apns + +import ( + "time" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mockConn struct{} + +func (m mockConn) Read(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Write(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) Connect() error { + return nil +} + +func (m mockConn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { + return 0, nil +} + +var _ = Describe("Session", func() { + Describe("NewSession", func() { + It("creates a session", func() { + s := NewSession(mockConn{}) + Expect(s).NotTo(BeNil()) + }) + }) + + Describe("Connect", func() { + Context("new state", func() { + It("should not return an error", func() { + s := NewSession(mockConn{}) + + err := s.Connect() + Expect(err).To(BeNil()) + }) + }) + + Context("not new state", func() { + It("should return an error", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + err := s.Connect() + Expect(err).NotTo(BeNil()) + }) + }) + }) + + Describe("Disconnected", func() { + Context("not connected", func() { + It("should not be true", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + Expect(s.Disconnected()).To(BeTrue()) + }) + }) + + Context("connected", func() { + It("should be false", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.Connect() + + Expect(s.Disconnected()).To(BeFalse()) + }) + }) + }) + + Describe("Send", func() { + }) + + Describe("Disconnect", func() { + }) + + Describe("RequeueableNotifications", func() { + }) +}) From f87a8ee10c8ef470fe1052ff28edf88fec84a25d Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 26 Feb 2015 16:30:00 -0500 Subject: [PATCH 20/32] Synchronize around the session state --- session.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index ec9fd67..c7dd832 100644 --- a/session.go +++ b/session.go @@ -82,7 +82,7 @@ func NewSession(conn Conn) Session { } func (s *session) Connect() error { - if s.st != sessionStateNew { + if s.isNew() { return errors.New("can't connect unless the session is new") } @@ -90,13 +90,30 @@ func (s *session) Connect() error { return nil } +func (s *session) isNew() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st != sessionStateNew +} + func (s *session) Disconnected() bool { + s.stm.Lock() + defer s.stm.Unlock() + return s.st == sessionStateDisconnected } +func (s *session) Connnected() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st == sessionStateConnected +} + func (s *session) Send(n Notification) error { // If disconnected, error out - if s.st != sessionStateConnected { + if s.Connnected() { return errors.New("not connected") } From 1456b9b1c120f3c7f9617d2396e191fc263e51f8 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 4 Mar 2015 15:49:23 -0500 Subject: [PATCH 21/32] How do you even logic --- session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/session.go b/session.go index c7dd832..dc7e416 100644 --- a/session.go +++ b/session.go @@ -113,7 +113,7 @@ func (s *session) Connnected() bool { func (s *session) Send(n Notification) error { // If disconnected, error out - if s.Connnected() { + if !s.Connnected() { return errors.New("not connected") } From f841b7ef66998b519efb438938cf3c656db0ec1a Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 4 Mar 2015 15:51:18 -0500 Subject: [PATCH 22/32] Clean up session states --- session.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/session.go b/session.go index dc7e416..36e4576 100644 --- a/session.go +++ b/session.go @@ -50,9 +50,9 @@ func (b *buffer) Add(v interface{}) *list.Element { type sessionState int const ( - sessionStateNew sessionState = iota - sessionStateConnected sessionState = iota - sessionStateDisconnected sessionState = iota + sessionStateNew sessionState = 1 << iota + sessionStateConnected + sessionStateDisconnected ) type session struct { @@ -143,11 +143,7 @@ func (s *session) send(b []byte) error { return err } - if err != nil { - return err - } - - return nil + return err } func (s *session) Disconnect() { From b8a8e426c85ad9ef6d543cfff23ab5507a7d2e94 Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Tue, 21 Apr 2015 11:53:48 -0600 Subject: [PATCH 23/32] Revise Travis CI config use containers, Go 1.4, and update cover location --- .travis.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 687544f..31c7754 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,12 +1,11 @@ +sudo: false language: go go: - - 1.3 -services: - - redis-server + - 1.4.2 before_script: - go get github.com/onsi/ginkgo - go get github.com/onsi/gomega - - go get code.google.com/p/go.tools/cmd/cover + - go get golang.org/x/tools/cmd/cover - go install github.com/onsi/ginkgo/ginkgo script: ginkgo -r --skipMeasurements --cover --trace --race env: From 915e3b9d671681af7b9f5cdace2ff4c4dbbe1a5c Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Tue, 21 Apr 2015 11:57:04 -0600 Subject: [PATCH 24/32] travis ci: notifications run on all pull requests and any commits directly to master or develop. don't send emails. --- .travis.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.travis.yml b/.travis.yml index 31c7754..0987969 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,3 +11,11 @@ script: ginkgo -r --skipMeasurements --cover --trace --race env: global: - PATH=$HOME/gopath/bin:$PATH + +notifications: + email: false + +branches: + only: + - master + - develop From 509f1d781778dbd565e0229b1782414af3819678 Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Thu, 23 Apr 2015 13:39:48 -0600 Subject: [PATCH 25/32] document exported identifiers and lint fixes --- client.go | 5 +++++ client_test.go | 8 ++++---- conn.go | 14 ++++++++++---- doc.go | 7 +++---- error.go | 34 +++++++++++++++++++++------------- feedback.go | 5 +++++ feedback_test.go | 5 +++-- notification.go | 23 ++++++++++++++++------- session.go | 14 +++++++++++--- 9 files changed, 78 insertions(+), 37 deletions(-) diff --git a/client.go b/client.go index b9159bb..689d87e 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "time" ) +// Client creates a session with Apple and handles reconnection. type Client struct { conn Conn @@ -25,11 +26,13 @@ func newClientWithConn(conn Conn) (Client, error) { return Client{conn, sess, sync.Mutex{}}, nil } +// NewClientWithCert creates a client of the Apple gateway given a certificate. func NewClientWithCert(gw string, cert tls.Certificate) (Client, error) { conn := NewConnWithCert(gw, cert) return newClientWithConn(conn) } +// NewClient is a helper that accepts a certificate/key pair. func NewClient(gw string, cert string, key string) (Client, error) { conn, err := NewConn(gw, cert, key) if err != nil { @@ -39,6 +42,7 @@ func NewClient(gw string, cert string, key string) (Client, error) { return newClientWithConn(conn) } +// NewClientWithFiles is a helper that loads a certificate/key from files. func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, error) { conn, err := NewConnWithFiles(gw, certFile, keyFile) if err != nil { @@ -48,6 +52,7 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err return newClientWithConn(conn) } +// Send a notification, handling disconnections. func (c *Client) Send(n Notification) error { if c.sess.Disconnected() { c.reconnectAndRequeue() diff --git a/client_test.go b/client_test.go index 7dbee14..5363eb4 100644 --- a/client_test.go +++ b/client_test.go @@ -48,7 +48,7 @@ type badConnMockSession struct { *mockSession } -func (_ badConnMockSession) Connect() error { +func (m badConnMockSession) Connect() error { return errors.New("whatev") } @@ -155,7 +155,7 @@ var _ = Describe("Client", func() { It("should reconnect", func() { newSessCount := 0 newSession = func(_ Conn) Session { - newSessCount += 1 + newSessCount++ return &mockSession{} } @@ -176,7 +176,7 @@ var _ = Describe("Client", func() { sendCount := 0 newSession = func(_ Conn) Session { - newSessCount += 1 + newSessCount++ return &mockSession{ requeueNotifs: []Notification{ Notification{}, @@ -184,7 +184,7 @@ var _ = Describe("Client", func() { Notification{}, }, sendCB: func(_ Notification) error { - sendCount += 1 + sendCount++ return nil }, } diff --git a/conn.go b/conn.go index b14f8b8..89722d7 100644 --- a/conn.go +++ b/conn.go @@ -9,11 +9,15 @@ import ( ) const ( + // ProductionGateway is the host for Apple Push Notification server. ProductionGateway = "gateway.push.apple.com:2195" - SandboxGateway = "gateway.sandbox.push.apple.com:2195" + // SandboxGateway is Apple's gateway for development. + SandboxGateway = "gateway.sandbox.push.apple.com:2195" + // ProductionFeedbackGateway is Apple's feedback service. ProductionFeedbackGateway = "feedback.push.apple.com:2196" - SandboxFeedbackGateway = "feedback.sandbox.push.apple.com:2196" + // SandboxFeedbackGateway is Apple's feedback service for development. + SandboxFeedbackGateway = "feedback.sandbox.push.apple.com:2196" ) // Conn is a wrapper for the actual TLS connections made to Apple @@ -32,6 +36,7 @@ type conn struct { connected bool } +// NewConnWithCert creates a new Conn from a certificate. func NewConnWithCert(gw string, cert tls.Certificate) Conn { gatewayParts := strings.Split(gw, ":") tls := tls.Config{ @@ -43,7 +48,7 @@ func NewConnWithCert(gw string, cert tls.Certificate) Conn { return &conn{gateway: gw, tls: &tls} } -// NewConnWithFiles creates a new Conn from certificate and key in the specified files +// NewConn creates a new Conn from certificate and key pair. func NewConn(gw string, crt string, key string) (Conn, error) { cert, err := tls.X509KeyPair([]byte(crt), []byte(key)) if err != nil { @@ -53,7 +58,7 @@ func NewConn(gw string, crt string, key string) (Conn, error) { return NewConnWithCert(gw, cert), nil } -// NewConnWithFiles creates a new Conn from certificate and key in the specified files +// NewConnWithFiles creates a new Conn from certificate and key in the specified files. func NewConnWithFiles(gw string, certFile string, keyFile string) (Conn, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { @@ -79,6 +84,7 @@ func (c *conn) Connect() error { return nil } +// Close the connection. func (c *conn) Close() error { if c.netConn != nil { return c.netConn.Close() diff --git a/doc.go b/doc.go index 02149d7..85b219e 100644 --- a/doc.go +++ b/doc.go @@ -1,13 +1,12 @@ /* -A Go package to interface with the Apple Push -Notification Service +Package apns provides an interface with the Apple Push Notification Service. Features This library implements a few features that we couldn't find in any one library elsewhere: - Long Lived Clients - Apple's documentation say that you should hold a + Long Lived Clients - Apple's documentation says that you should hold a persistent connection open and not create new connections for every payload See: https://developer.apple.com/library/ios/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/Chapters/CommunicatingWIthAPS.html#//apple_ref/doc/uid/TP40008194-CH101-SW6) @@ -16,7 +15,7 @@ library elsewhere: variable length payloads. This library uses that protocol. - Robust Send Guarantees - APNS has asynchronous feedback on whether a push + Robust Send Guarantees - Apple has asynchronous feedback on whether a push sent. That means that if you send pushes after a bad send, those pushes will be lost forever. Our library records the last N pushes, detects errors, diff --git a/error.go b/error.go index 3371bea..a744076 100644 --- a/error.go +++ b/error.go @@ -3,26 +3,32 @@ package apns import ( "bytes" "encoding/binary" - "errors" -) - -var ( - ErrDisconnected = errors.New("disconnected from gateway") ) const ( // Error strings based on the codes specified here: // https://developer.apple.com/library/ios/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/Chapters/CommunicatingWIthAPS.html#//apple_ref/doc/uid/TP40008194-CH101-SW12 - ErrProcessing = "Processing error" + + // ErrProcessing (1) + ErrProcessing = "Processing error" + // ErrMissingDeviceToken (2) ErrMissingDeviceToken = "Missing device token" - ErrMissingTopic = "Missing topic" - ErrMissingPayload = "Missing payload" - ErrInvalidTokenSize = "Invalid token size" - ErrInvalidTopicSize = "Invalid topic size" + // ErrMissingTopic (3) + ErrMissingTopic = "Missing topic" + // ErrMissingPayload (4) when no payload. + ErrMissingPayload = "Missing payload" + // ErrInvalidTokenSize (5) for a device token that is the wrong size. + ErrInvalidTokenSize = "Invalid token size" + // ErrInvalidTopicSize (6) + ErrInvalidTopicSize = "Invalid topic size" + // ErrInvalidPayloadSize (7) for a payload over 2 KB. ErrInvalidPayloadSize = "Invalid payload size" - ErrInvalidToken = "Invalid token" - ErrShutdown = "Shutdown" - ErrUnknown = "None (unknown)" + // ErrInvalidToken (8) such as a production device token used with the sandbox gateway. + ErrInvalidToken = "Invalid token" + // ErrShutdown (10) closed connection to perform maintenance. Open a new connection. + ErrShutdown = "Shutdown" + // ErrUnknown (255) + ErrUnknown = "None (unknown)" ) var errorMapping = map[uint8]string{ @@ -38,6 +44,7 @@ var errorMapping = map[uint8]string{ 255: ErrUnknown, } +// Error captures the details of an error read from Apple's Push Notification server. type Error struct { Command uint8 Status uint8 @@ -45,6 +52,7 @@ type Error struct { ErrStr string } +// NewError parses an error from Apple. func NewError(p []byte) Error { if len(p) != 1+1+4 { return Error{ErrStr: ErrUnknown} diff --git a/feedback.go b/feedback.go index a0093b1..41cb707 100644 --- a/feedback.go +++ b/feedback.go @@ -8,10 +8,12 @@ import ( "time" ) +// Feedback is a connection to Apple's feedback service. type Feedback struct { Conn Conn } +// FeedbackTuple represents the feedback received from Apple. type FeedbackTuple struct { Timestamp time.Time TokenLength uint16 @@ -37,12 +39,14 @@ func feedbackTupleFromBytes(b []byte) FeedbackTuple { } } +// NewFeedbackWithCert creates a new feedback service client with a certificate. func NewFeedbackWithCert(gw string, cert tls.Certificate) Feedback { conn := NewConnWithCert(gw, cert) return Feedback{Conn: conn} } +// NewFeedback creates a new feedback service client with a certificate/key pair. func NewFeedback(gw string, cert string, key string) (Feedback, error) { conn, err := NewConn(gw, cert, key) if err != nil { @@ -52,6 +56,7 @@ func NewFeedback(gw string, cert string, key string) (Feedback, error) { return Feedback{Conn: conn}, nil } +// NewFeedbackWithFiles creates a new feedback service client from certificate and key files. func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, error) { conn, err := NewConnWithFiles(gw, certFile, keyFile) if err != nil { diff --git a/feedback_test.go b/feedback_test.go index 96dadb4..938ffcd 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -9,6 +9,7 @@ import ( "net" "os" "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" @@ -84,7 +85,7 @@ var _ = Describe("Feedback", func() { r := 0 for _ = range c { - r += 1 + r++ } Expect(r).To(Equal(0)) @@ -104,7 +105,7 @@ var _ = Describe("Feedback", func() { r := 0 for _ = range c { - r += 1 + r++ } Expect(r).To(Equal(0)) diff --git a/notification.go b/notification.go index 7fe827e..ff0e7fa 100644 --- a/notification.go +++ b/notification.go @@ -11,7 +11,9 @@ import ( ) const ( - PriorityImmediate = 10 + // PriorityImmediate for time sensitive notifications (not for silent push messages). + PriorityImmediate = 10 + // PriorityPowerConserve for notifications that are less time sensitive. PriorityPowerConserve = 5 ) @@ -36,13 +38,8 @@ const ( priorityItemLength = 1 ) -type NotificationResult struct { - Notif Notification - Err Error -} - +// Alert to send. type Alert struct { - // Do not add fields without updating the implementation of isZero. Body string `json:"body,omitempty"` Title string `json:"title,omitempty"` Action string `json:"action,omitempty"` @@ -50,8 +47,11 @@ type Alert struct { LocArgs []string `json:"loc-args,omitempty"` ActionLocKey string `json:"action-loc-key,omitempty"` LaunchImage string `json:"launch-image,omitempty"` + + // Do not add fields without updating the implementation of isZero. } +// isSimple alerts only contain a Body. func (a *Alert) isSimple() bool { return len(a.Title) == 0 && len(a.Action) == 0 && len(a.LocKey) == 0 && len(a.LocArgs) == 0 && len(a.ActionLocKey) == 0 && len(a.LaunchImage) == 0 } @@ -60,6 +60,7 @@ func (a *Alert) isZero() bool { return a.isSimple() && len(a.Body) == 0 } +// APS is the Apple-reserved aps namespace in a push notification. type APS struct { Alert Alert Badge *int // 0 to clear notifications, nil to leave as is. @@ -69,6 +70,7 @@ type APS struct { Category string // requires iOS 8+ } +// MarshalJSON implements the json.Marshaler interface. func (aps APS) MarshalJSON() ([]byte, error) { data := make(map[string]interface{}) @@ -98,6 +100,7 @@ func (aps APS) MarshalJSON() ([]byte, error) { return json.Marshal(data) } +// Payload to send Apple. type Payload struct { APS APS // MDM for mobile device management @@ -105,6 +108,7 @@ type Payload struct { customValues map[string]interface{} } +// MarshalJSON implements the json.Marshaler interface. func (p *Payload) MarshalJSON() ([]byte, error) { if len(p.MDM) != 0 { p.customValues["mdm"] = p.MDM @@ -115,6 +119,7 @@ func (p *Payload) MarshalJSON() ([]byte, error) { return json.Marshal(p.customValues) } +// SetCustomValue sets a custom payload value. func (p *Payload) SetCustomValue(key string, value interface{}) error { if key == "aps" { return errors.New("cannot assign a custom APS value in payload") @@ -125,6 +130,7 @@ func (p *Payload) SetCustomValue(key string, value interface{}) error { return nil } +// Notification contains the payload. type Notification struct { ID string DeviceToken string @@ -134,14 +140,17 @@ type Notification struct { Payload *Payload } +// NewNotification creates a new notification. func NewNotification() Notification { return Notification{Payload: NewPayload()} } +// NewPayload creates a new payload. func NewPayload() *Payload { return &Payload{customValues: map[string]interface{}{}} } +// ToBinary encodes a notification to send it. func (n Notification) ToBinary() ([]byte, error) { b := []byte{} diff --git a/session.go b/session.go index 36e4576..5742c5e 100644 --- a/session.go +++ b/session.go @@ -7,6 +7,7 @@ import ( "sync" ) +// SessionError associates an error from Apple to a notification. type SessionError struct { Notification Notification Err Error @@ -16,6 +17,7 @@ func (s SessionError) Error() string { return s.Err.Error() } +// Session to Apple's Push Notification server. type Session interface { Send(n Notification) error Connect() error @@ -70,6 +72,7 @@ type session struct { err SessionError } +// NewSession creates a new session. func NewSession(conn Conn) Session { return &session{ st: sessionStateNew, @@ -81,6 +84,7 @@ func NewSession(conn Conn) Session { } } +// Connect session to gateway. func (s *session) Connect() error { if s.isNew() { return errors.New("can't connect unless the session is new") @@ -97,6 +101,7 @@ func (s *session) isNew() bool { return s.st != sessionStateNew } +// Disconnected indicates whether session is disconnected. func (s *session) Disconnected() bool { s.stm.Lock() defer s.stm.Unlock() @@ -104,16 +109,18 @@ func (s *session) Disconnected() bool { return s.st == sessionStateDisconnected } -func (s *session) Connnected() bool { +// Connected indicates whether session is connected. +func (s *session) Connected() bool { s.stm.Lock() defer s.stm.Unlock() return s.st == sessionStateConnected } +// Send notification to gateway. func (s *session) Send(n Notification) error { // If disconnected, error out - if !s.Connnected() { + if !s.Connected() { return errors.New("not connected") } @@ -146,11 +153,12 @@ func (s *session) send(b []byte) error { return err } +// Disconnect from gateway. func (s *session) Disconnect() { - // Disconnect s.transitionState(sessionStateDisconnected) } +// RequeueableNotifications returns good notifications sent after an error. func (s *session) RequeueableNotifications() []Notification { notifs := []Notification{} From 89c02e19d2b239639cae1cba7fb09f1bc86469a5 Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Thu, 23 Apr 2015 15:49:49 -0600 Subject: [PATCH 26/32] ErrUnrecognizedErrorResponse --- error.go | 15 ++++++++++----- error_test.go | 6 +++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/error.go b/error.go index a744076..20b390b 100644 --- a/error.go +++ b/error.go @@ -5,6 +5,11 @@ import ( "encoding/binary" ) +var ( + // ErrUnrecognizedErrorResponse when the error from Apple isn't recognized. + ErrUnrecognizedErrorResponse = "Unrecognized error or no error." +) + const ( // Error strings based on the codes specified here: // https://developer.apple.com/library/ios/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/Chapters/CommunicatingWIthAPS.html#//apple_ref/doc/uid/TP40008194-CH101-SW12 @@ -44,9 +49,9 @@ var errorMapping = map[uint8]string{ 255: ErrUnknown, } -// Error captures the details of an error read from Apple's Push Notification server. +// Error captures the error response read from Apple's Push Notification server. type Error struct { - Command uint8 + Command uint8 // always should be 8 Status uint8 Identifier uint32 ErrStr string @@ -54,8 +59,8 @@ type Error struct { // NewError parses an error from Apple. func NewError(p []byte) Error { - if len(p) != 1+1+4 { - return Error{ErrStr: ErrUnknown} + if len(p) != 6 { + return Error{ErrStr: ErrUnrecognizedErrorResponse} } r := bytes.NewBuffer(p) @@ -67,7 +72,7 @@ func NewError(p []byte) Error { var ok bool if e.ErrStr, ok = errorMapping[e.Status]; !ok { - e.ErrStr = ErrUnknown + e.ErrStr = ErrUnrecognizedErrorResponse } return e diff --git a/error_test.go b/error_test.go index 1c6bb78..12d063a 100644 --- a/error_test.go +++ b/error_test.go @@ -79,14 +79,14 @@ var _ = Describe("Error", func() { }) Context("error with unrecognized code", func() { - ShouldBeErrorWithErrStr(300, apns.ErrUnknown) + ShouldBeErrorWithErrStr(300, apns.ErrUnrecognizedErrorResponse) }) Context("not enough bytes", func() { - It("should be ErrUnknown", func() { + It("should be ErrUnrecognizedErrorResponse", func() { e := apns.NewError([]byte{}) Expect(e).NotTo(BeNil()) - Expect(e.ErrStr).To(Equal(apns.ErrUnknown)) + Expect(e.ErrStr).To(Equal(apns.ErrUnrecognizedErrorResponse)) }) }) }) From 933c627b3383289a830ec14aa0338369bbc385ad Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Thu, 23 Apr 2015 16:30:14 -0600 Subject: [PATCH 27/32] split out SetReadDeadline used by feedback service. --- apns_suite_test.go | 11 +++++------ conn.go | 13 ++++++------- feedback.go | 8 +++++++- feedback_test.go | 3 +-- session_test.go | 5 +++-- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/apns_suite_test.go b/apns_suite_test.go index 858a5b9..7d40d9f 100644 --- a/apns_suite_test.go +++ b/apns_suite_test.go @@ -11,7 +11,7 @@ import ( type mockConn struct { connect func() error read func([]byte) (int, error) - readWithTimeout func([]byte, time.Time) (int, error) + setReadDeadline func(time.Time) error } func (m *mockConn) Connect() error { @@ -37,12 +37,11 @@ func (m *mockConn) Close() error { return nil } -func (m *mockConn) ReadWithTimeout(b []byte, t time.Time) (int, error) { - if m.readWithTimeout != nil { - return m.readWithTimeout(b, t) +func (m *mockConn) SetReadDeadline(t time.Time) error { + if m.setReadDeadline != nil { + return m.setReadDeadline(t) } - - return 0, nil + return nil } func TestApns(t *testing.T) { diff --git a/conn.go b/conn.go index 89722d7..d76577b 100644 --- a/conn.go +++ b/conn.go @@ -25,7 +25,7 @@ type Conn interface { io.ReadWriteCloser Connect() error - ReadWithTimeout(p []byte, deadline time.Time) (int, error) + SetReadDeadline(deadline time.Time) error } type conn struct { @@ -36,7 +36,7 @@ type conn struct { connected bool } -// NewConnWithCert creates a new Conn from a certificate. +// NewConnWithCert creates a new connection from a certificate. func NewConnWithCert(gw string, cert tls.Certificate) Conn { gatewayParts := strings.Split(gw, ":") tls := tls.Config{ @@ -98,11 +98,10 @@ func (c *conn) Read(p []byte) (int, error) { return c.netConn.Read(p) } -// ReadWithTimeout reads data from the connection and returns an error -// after duration -func (c *conn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { - c.netConn.SetReadDeadline(deadline) - return c.netConn.Read(p) +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *conn) SetReadDeadline(deadline time.Time) error { + return c.netConn.SetReadDeadline(deadline) } // Write writes data from the connection diff --git a/feedback.go b/feedback.go index 41cb707..80d0e8a 100644 --- a/feedback.go +++ b/feedback.go @@ -85,7 +85,13 @@ func (f Feedback) receive(fc chan FeedbackTuple) { for { b := make([]byte, 38) - _, err := f.Conn.ReadWithTimeout(b, time.Now().Add(100*time.Millisecond)) + err = f.Conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if err != nil { + close(fc) + return + } + + _, err = f.Conn.Read(b) if err != nil { close(fc) return diff --git a/feedback_test.go b/feedback_test.go index 938ffcd..59739e2 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -8,7 +8,6 @@ import ( "io/ioutil" "net" "os" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -95,7 +94,7 @@ var _ = Describe("Feedback", func() { Context("times out", func() { It("should not receive anything", func() { m := mockConn{ - readWithTimeout: func(b []byte, t time.Time) (int, error) { + read: func(b []byte) (int, error) { return 0, net.UnknownNetworkError("") }, } diff --git a/session_test.go b/session_test.go index 67a598f..aaabc1b 100644 --- a/session_test.go +++ b/session_test.go @@ -2,6 +2,7 @@ package apns import ( "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -24,8 +25,8 @@ func (m mockConn) Connect() error { return nil } -func (m mockConn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { - return 0, nil +func (m mockConn) SetReadDeadline(deadline time.Time) error { + return nil } var _ = Describe("Session", func() { From 54aab15acdbae3068036a2e00b222c8cc91d4a0c Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Fri, 15 May 2015 16:08:32 -0600 Subject: [PATCH 28/32] keep NotificationResult don't break API for now --- session.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/session.go b/session.go index 5742c5e..4de3a13 100644 --- a/session.go +++ b/session.go @@ -7,13 +7,13 @@ import ( "sync" ) -// SessionError associates an error from Apple to a notification. -type SessionError struct { - Notification Notification - Err Error +// NotificationResult associates an error from Apple to a notification. +type NotificationResult struct { + Notif Notification + Err Error } -func (s SessionError) Error() string { +func (s NotificationResult) Error() string { return s.Err.Error() } @@ -69,7 +69,7 @@ type session struct { id uint32 idm sync.Mutex - err SessionError + err NotificationResult } // NewSession creates a new session. @@ -170,7 +170,7 @@ func (s *session) RequeueableNotifications() []Notification { // Walk back to last known good notification and return the slice var e *list.Element for e = s.b.Front(); e != nil; e = e.Next() { - if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Notification.Identifier { + if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Notif.Identifier { break } } @@ -231,7 +231,7 @@ func (s *session) readErrors() { // If the notification, move cursor after the trouble notification if n.Identifier == e.Identifier { - s.err = SessionError{n, e} + s.err = NotificationResult{n, e} } } } From 3fa96efe89c7556ab52de09226d858041f3ce4c0 Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Wed, 20 May 2015 12:26:39 -0600 Subject: [PATCH 29/32] remove determineIdentifier (auto ID) at least for now, may bring this functionality back ref #19 --- session.go | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/session.go b/session.go index 4de3a13..1cedffa 100644 --- a/session.go +++ b/session.go @@ -66,9 +66,6 @@ type session struct { st sessionState stm sync.Mutex - id uint32 - idm sync.Mutex - err NotificationResult } @@ -79,7 +76,6 @@ func NewSession(conn Conn) Session { stm: sync.Mutex{}, conn: conn, connm: sync.Mutex{}, - idm: sync.Mutex{}, b: newBuffer(50), } } @@ -124,9 +120,6 @@ func (s *session) Send(n Notification) error { return errors.New("not connected") } - // Set identifier if not specified - n.Identifier = s.determineIdentifier(n.Identifier) - // Serialize b, err := n.ToBinary() if err != nil { @@ -195,22 +188,6 @@ func (s *session) transitionState(st sessionState) { s.st = st } -func (s *session) determineIdentifier(n uint32) uint32 { - s.idm.Lock() - defer s.idm.Unlock() - - // If the id passed in is 0, that means it wasn't - // set so get the next ID. Otherwise, set it to that - // identifier. - if n == 0 { - s.id++ - } else { - s.id = n - } - - return s.id -} - func (s *session) readErrors() { p := make([]byte, 6, 6) From 0b6d1b05d3ee62f0d15d301255bb5fd8659e9938 Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Wed, 20 May 2015 14:09:26 -0600 Subject: [PATCH 30/32] move sent buffer code and session error is just the error from Apple --- buffer.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++++ buffer_test.go | 11 +++++++ session.go | 82 +++------------------------------------------- session_test.go | 3 -- 4 files changed, 101 insertions(+), 81 deletions(-) create mode 100644 buffer.go create mode 100644 buffer_test.go diff --git a/buffer.go b/buffer.go new file mode 100644 index 0000000..dcc1d2e --- /dev/null +++ b/buffer.go @@ -0,0 +1,86 @@ +package apns + +import ( + "container/list" + "sync" +) + +// circular buffer of sent messages +// to retry if connection is dropped +type buffer struct { + size int + m sync.Mutex + *list.List +} + +func newBuffer(size int) *buffer { + return &buffer{size, sync.Mutex{}, list.New()} +} + +func (b *buffer) Add(v interface{}) *list.Element { + b.m.Lock() + defer b.m.Unlock() + + e := b.PushBack(v) + + if b.Len() > b.size { + b.Remove(b.Front()) + } + + return e +} + +// NotificationResult associates an error from Apple to a notification. +type NotificationResult struct { + Notif Notification + Err Error +} + +func (s NotificationResult) Error() string { + return s.Err.Error() +} + +func (s *session) FindFailedNotification() NotificationResult { + e := s.err + + for cursor := s.b.Back(); cursor != nil; cursor = cursor.Prev() { + // Get serialized notification + n, _ := cursor.Value.(Notification) + + // If the notification, move cursor after the trouble notification + if n.Identifier == e.Identifier { + return NotificationResult{n, e} + } + } + return NotificationResult{Notification{}, e} +} + +// RequeueableNotifications returns good notifications sent after an error. +func (s *session) RequeueableNotifications() []Notification { + notifs := []Notification{} + + // If still connected, return nothing + if s.st != sessionStateDisconnected { + return notifs + } + + // Walk back to last known good notification and return the slice + var e *list.Element + for e = s.b.Front(); e != nil; e = e.Next() { + if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Identifier { + break + } + } + + // Start right after errored ID and get the rest of the list + for e = e.Next(); e != nil; e = e.Next() { + n, ok := e.Value.(Notification) + if !ok { + continue + } + + notifs = append(notifs, n) + } + + return notifs +} diff --git a/buffer_test.go b/buffer_test.go new file mode 100644 index 0000000..d6260c4 --- /dev/null +++ b/buffer_test.go @@ -0,0 +1,11 @@ +package apns + +import ( + . "github.com/onsi/ginkgo" + // . "github.com/onsi/gomega" +) + +var _ = Describe("Session", func() { + Describe("RequeueableNotifications", func() { + }) +}) diff --git a/session.go b/session.go index 1cedffa..df861a3 100644 --- a/session.go +++ b/session.go @@ -1,22 +1,11 @@ package apns import ( - "container/list" "errors" "io" "sync" ) -// NotificationResult associates an error from Apple to a notification. -type NotificationResult struct { - Notif Notification - Err Error -} - -func (s NotificationResult) Error() string { - return s.Err.Error() -} - // Session to Apple's Push Notification server. type Session interface { Send(n Notification) error @@ -26,29 +15,6 @@ type Session interface { Disconnected() bool } -type buffer struct { - size int - m sync.Mutex - *list.List -} - -func newBuffer(size int) *buffer { - return &buffer{size, sync.Mutex{}, list.New()} -} - -func (b *buffer) Add(v interface{}) *list.Element { - b.m.Lock() - defer b.m.Unlock() - - e := b.PushBack(v) - - if b.Len() > b.size { - b.Remove(b.Front()) - } - - return e -} - type sessionState int const ( @@ -66,7 +32,7 @@ type session struct { st sessionState stm sync.Mutex - err NotificationResult + err Error } // NewSession creates a new session. @@ -130,10 +96,10 @@ func (s *session) Send(n Notification) error { s.b.Add(n) // Send synchronously - return s.send(b) + return s.write(b) } -func (s *session) send(b []byte) error { +func (s *session) write(b []byte) error { s.connm.Lock() defer s.connm.Unlock() @@ -151,36 +117,6 @@ func (s *session) Disconnect() { s.transitionState(sessionStateDisconnected) } -// RequeueableNotifications returns good notifications sent after an error. -func (s *session) RequeueableNotifications() []Notification { - notifs := []Notification{} - - // If still connected, return nothing - if s.st != sessionStateDisconnected { - return notifs - } - - // Walk back to last known good notification and return the slice - var e *list.Element - for e = s.b.Front(); e != nil; e = e.Next() { - if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Notif.Identifier { - break - } - } - - // Start right after errored ID and get the rest of the list - for e = e.Next(); e != nil; e = e.Next() { - n, ok := e.Value.(Notification) - if !ok { - continue - } - - notifs = append(notifs, n) - } - - return notifs -} - func (s *session) transitionState(st sessionState) { s.stm.Lock() defer s.stm.Unlock() @@ -200,15 +136,5 @@ func (s *session) readErrors() { s.Disconnect() - e := NewError(p) - - for cursor := s.b.Back(); cursor != nil; cursor = cursor.Prev() { - // Get serialized notification - n, _ := cursor.Value.(Notification) - - // If the notification, move cursor after the trouble notification - if n.Identifier == e.Identifier { - s.err = NotificationResult{n, e} - } - } + s.err = NewError(p) } diff --git a/session_test.go b/session_test.go index aaabc1b..e70cbfb 100644 --- a/session_test.go +++ b/session_test.go @@ -89,7 +89,4 @@ var _ = Describe("Session", func() { Describe("Disconnect", func() { }) - - Describe("RequeueableNotifications", func() { - }) }) From 57489e737e466091e7b62703b647ba414cedce97 Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Wed, 20 May 2015 14:46:02 -0600 Subject: [PATCH 31/32] find requeueable notifications on buffer --- buffer.go | 17 +++++------------ client.go | 16 ++++++++++------ session.go | 4 ++++ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/buffer.go b/buffer.go index dcc1d2e..23ba78f 100644 --- a/buffer.go +++ b/buffer.go @@ -40,10 +40,8 @@ func (s NotificationResult) Error() string { return s.Err.Error() } -func (s *session) FindFailedNotification() NotificationResult { - e := s.err - - for cursor := s.b.Back(); cursor != nil; cursor = cursor.Prev() { +func (b *buffer) FindFailedNotification(e Error) NotificationResult { + for cursor := b.Back(); cursor != nil; cursor = cursor.Prev() { // Get serialized notification n, _ := cursor.Value.(Notification) @@ -56,18 +54,13 @@ func (s *session) FindFailedNotification() NotificationResult { } // RequeueableNotifications returns good notifications sent after an error. -func (s *session) RequeueableNotifications() []Notification { +func (b *buffer) RequeueableNotifications(identifier uint32) []Notification { notifs := []Notification{} - // If still connected, return nothing - if s.st != sessionStateDisconnected { - return notifs - } - // Walk back to last known good notification and return the slice var e *list.Element - for e = s.b.Front(); e != nil; e = e.Next() { - if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Identifier { + for e = b.Front(); e != nil; e = e.Next() { + if n, ok := e.Value.(Notification); ok && n.Identifier == identifier { break } } diff --git a/client.go b/client.go index 689d87e..52c34d3 100644 --- a/client.go +++ b/client.go @@ -68,7 +68,15 @@ func (c *Client) reconnectAndRequeue() { // Pull off undelivered notifications notifs := c.sess.RequeueableNotifications() - // Reconnect + c.reconnect() + + for _, n := range notifs { + // TODO handle error from sending + c.sess.Send(n) + } +} + +func (c *Client) reconnect() { c.sess = nil for c.sess == nil { @@ -85,13 +93,9 @@ func (c *Client) reconnectAndRequeue() { c.sess = sess } - - for _, n := range notifs { - // TODO handle error from sending - c.sess.Send(n) - } } +// newSession for altering in tests var newSession = func(c Conn) Session { return NewSession(c) } diff --git a/session.go b/session.go index df861a3..5a5fc3b 100644 --- a/session.go +++ b/session.go @@ -71,6 +71,10 @@ func (s *session) Disconnected() bool { return s.st == sessionStateDisconnected } +func (s *session) RequeueableNotifications() []Notification { + return s.b.RequeueableNotifications(s.err.Identifier) +} + // Connected indicates whether session is connected. func (s *session) Connected() bool { s.stm.Lock() From 55cf34ec03aab6f39099d34e13c06f469f202fb3 Mon Sep 17 00:00:00 2001 From: Nathan Youngman Date: Wed, 20 May 2015 14:52:41 -0600 Subject: [PATCH 32/32] catch json marshal errors --- notification.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/notification.go b/notification.go index ff0e7fa..1dd8d5d 100644 --- a/notification.go +++ b/notification.go @@ -163,7 +163,10 @@ func (n Notification) ToBinary() ([]byte, error) { return b, fmt.Errorf("convert token to hex error: %s", err) } - j, _ := json.Marshal(n.Payload) + j, err := json.Marshal(n.Payload) + if err != nil { + return b, fmt.Errorf("json marshal error: %s", err) + } buf := bytes.NewBuffer(b)