diff --git a/client.go b/client.go index de7ab1a..5d9eeb4 100644 --- a/client.go +++ b/client.go @@ -48,12 +48,15 @@ func newClientWithConn(gw string, conn Conn) Client { return c } -func NewClientWithCert(gw string, cert tls.Certificate) Client { - conn := NewConnWithCert(gw, cert) - +func NewClietWithCertTimeout(gw string, cert tls.Certificate, timeout int) Client { + conn := NewConnWithCertTimeout(gw, cert, timeout) return newClientWithConn(gw, conn) } +func NewClientWithCert(gw string, cert tls.Certificate) Client { + return NewClietWithCertTimeout(gw, cert, 0) +} + func NewClient(gw string, cert string, key string) (Client, error) { conn, err := NewConn(gw, cert, key) if err != nil { @@ -63,8 +66,8 @@ func NewClient(gw string, cert string, key string) (Client, error) { return newClientWithConn(gw, conn), nil } -func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, error) { - conn, err := NewConnWithFiles(gw, certFile, keyFile) +func NewClientWithFilesTimeout(gw string, certFile string, keyFile string, timeout int) (Client, error) { + conn, err := NewConnWithFilesTimeout(gw, certFile, keyFile, timeout) if err != nil { return Client{}, err } @@ -72,6 +75,10 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err return newClientWithConn(gw, conn), nil } +func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, error) { + return NewClientWithFilesTimeout(gw, certFile, keyFile, 0) +} + func (c *Client) Send(n Notification) error { c.notifs <- n return nil @@ -132,6 +139,7 @@ func (c *Client) runLoop() { if err != nil { // TODO Probably want to exponentially backoff... time.Sleep(1 * time.Second) + log.Println("err connecting to apns ", err.Error()) continue } @@ -205,9 +213,11 @@ func readErrs(c *Conn) chan error { go func() { p := make([]byte, 6, 6) + c.NetConn.SetReadDeadline(time.Time{}) _, err := c.Read(p) if err != nil { errs <- err + log.Println("read err", err.Error()) return } diff --git a/client_test.go b/client_test.go index c9dfd47..65c5a39 100644 --- a/client_test.go +++ b/client_test.go @@ -3,12 +3,12 @@ package apns_test import ( "bytes" "encoding/binary" - "io/ioutil" - "os" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "io/ioutil" + "os" + "time" ) var _ = Describe("Client", func() { diff --git a/conn.go b/conn.go index d3aa712..a23113f 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net" "strings" + "time" ) const ( @@ -18,19 +19,24 @@ const ( type Conn struct { NetConn net.Conn Conf *tls.Config + timeout time.Duration gateway string connected bool } -func NewConnWithCert(gw string, cert tls.Certificate) Conn { +func NewConnWithCertTimeout(gw string, cert tls.Certificate, timeout int) Conn { gatewayParts := strings.Split(gw, ":") conf := tls.Config{ Certificates: []tls.Certificate{cert}, ServerName: gatewayParts[0], } - return Conn{gateway: gw, Conf: &conf} + return Conn{gateway: gw, Conf: &conf, timeout: time.Duration(timeout) * time.Second} +} + +func NewConnWithCert(gw string, cert tls.Certificate) Conn { + return NewConnWithCertTimeout(gw, cert, 0) } // NewConnWithFiles creates a new Conn from certificate and key in the specified files @@ -44,13 +50,18 @@ func NewConn(gw string, crt string, key string) (Conn, error) { } // NewConnWithFiles creates a new Conn from certificate and key in the specified files -func NewConnWithFiles(gw string, certFile string, keyFile string) (Conn, error) { +func NewConnWithFilesTimeout(gw string, certFile string, keyFile string, timeout int) (Conn, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return Conn{}, err } - return NewConnWithCert(gw, cert), nil + return NewConnWithCertTimeout(gw, cert, timeout), nil +} + +// NewConnWithFiles creates a new Conn from certificate and key in the specified files +func NewConnWithFiles(gw string, certFile string, keyFile string) (Conn, error) { + return NewConnWithFilesTimeout(gw, certFile, keyFile, 0) } // Connect actually creates the TLS connection @@ -60,12 +71,21 @@ func (c *Conn) Connect() error { c.NetConn.Close() } - conn, err := net.Dial("tcp", c.gateway) + var conn net.Conn + var err error + if c.timeout > 0 { + conn, err = net.DialTimeout("tcp", c.gateway, c.timeout) + } else { + conn, err = net.Dial("tcp", c.gateway) + } if err != nil { return err } tlsConn := tls.Client(conn, c.Conf) + if c.timeout > 0 { + tlsConn.SetDeadline(time.Now().Add(c.timeout * 3)) + } err = tlsConn.Handshake() if err != nil { return err @@ -91,5 +111,8 @@ func (c *Conn) Read(p []byte) (int, error) { // Write writes data from the connection func (c *Conn) Write(p []byte) (int, error) { + if c.timeout > 0 { + c.NetConn.SetWriteDeadline(time.Now().Add(c.timeout)) + } return c.NetConn.Write(p) } diff --git a/conn_test.go b/conn_test.go index e910e6c..1ca4f79 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,6 +4,9 @@ import ( "bytes" "crypto/tls" "fmt" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/timehop/apns" "io" "io/ioutil" "log" @@ -11,9 +14,6 @@ import ( "os" "strings" "time" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/timehop/apns" ) var DummyCert = `-----BEGIN CERTIFICATE----- diff --git a/feedback_test.go b/feedback_test.go index 29978b4..dcfe620 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -4,12 +4,12 @@ import ( "bytes" "encoding/binary" "encoding/hex" - "io/ioutil" - "os" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "io/ioutil" + "os" + "time" ) var _ = Describe("Feedback", func() {