diff --git a/internal/dev/server.go b/internal/dev/server.go index e10a5c8e..e4318944 100644 --- a/internal/dev/server.go +++ b/internal/dev/server.go @@ -29,6 +29,12 @@ import ( "golang.org/x/net/http2" ) +const ( + maxConnectionFailures = 20 + maxReconnectBaseDelay = time.Millisecond * 250 + maxReconnectMaxDelay = time.Second * 10 +) + var propagator propagation.TraceContext type Server struct { @@ -55,10 +61,16 @@ type Server struct { expiresAt *time.Time tlsCertificate *tls.Certificate conn *tls.Conn - srv *http2.Server wg sync.WaitGroup serverAddr string cleanup func() + + // Connection state + connectionLock sync.Mutex + reconnectFailures int + connectionFailed time.Time + connectionStarted time.Time + reconnectMutex sync.Mutex } type ServerArgs struct { @@ -147,10 +159,45 @@ func (s *Server) reconnect() { func (s *Server) connect(initial bool) { var gerr error + + // hold a connection lock to prevent multiple go routines from trying to reconnect + // before the previous connect goroutine has finished + s.connectionLock.Lock() + defer s.connectionLock.Unlock() + defer func() { if initial && gerr != nil { s.connected <- gerr.Error() } + s.logger.Debug("connection closed") + select { + case <-s.ctx.Done(): + return + default: + var count int + var started time.Time + s.reconnectMutex.Lock() + if s.reconnectFailures == 0 { + s.connectionFailed = time.Now() + s.logger.Warn("lost connection to the dev server, reconnecting ...") + } + s.reconnectFailures++ + started = s.connectionFailed + count = s.reconnectFailures + s.reconnectMutex.Unlock() + if count >= maxConnectionFailures { + s.logger.Fatal("Too many connection failures, giving up after %d attempts (%s). You may need to re-run `agentuity dev`. If this error persists, please contact support.", count, time.Since(started)) + return + } + baseDelay := maxReconnectBaseDelay + wait := baseDelay * time.Duration(math.Pow(2, float64(count-1))) + if wait > maxReconnectMaxDelay { + wait = maxReconnectMaxDelay + } + s.logger.Debug("reconnecting in %s after %d connection failures (%s)", wait, count, time.Since(started)) + time.Sleep(wait) + s.reconnect() + } }() if err := s.refreshConnection(); err != nil { @@ -174,7 +221,7 @@ func (s *Server) connect(initial bool) { conn, err := tls.Dial("tcp", s.serverAddr, &tlsConfig) if err != nil { gerr = err - s.logger.Error("failed to dial tls: %s", err) + s.logger.Warn("failed to dial tls: %s, will retry ...", err) return } s.conn = conn @@ -183,6 +230,19 @@ func (s *Server) connect(initial bool) { s.connected <- "" } + // if we successfully connect, reset our connection failures + s.reconnectMutex.Lock() + if s.reconnectFailures > 0 && !s.connectionFailed.IsZero() { + s.logger.Debug("reconnection successful after %s (%d attempts)", time.Since(s.connectionFailed), s.reconnectFailures) + s.logger.Info("✅ connection to the dev server re-established") + } + s.reconnectFailures = 0 + s.connectionStarted = time.Now() + s.connectionFailed = time.Time{} + s.reconnectMutex.Unlock() + + s.logger.Debug("connection established to %s", s.serverAddr) + // HTTP/2 server to accept proxied requests over the tunnel connection h2s := &http2.Server{}