diff --git a/pkg/tlsx/tls/tls.go b/pkg/tlsx/tls/tls.go index c07a5ed2..f9fb1102 100644 --- a/pkg/tlsx/tls/tls.go +++ b/pkg/tlsx/tls/tls.go @@ -210,8 +210,26 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con threads = len(toEnumerate) } + // Build a context that respects the global timeout so cipher enumeration + // cannot block forever if a host stops responding mid-scan. + // context.Background() was previously used here, causing pool.Acquire to + // hang indefinitely when all pool connections were exhausted (issue #819). + enumCtx := context.Background() + var enumCancel context.CancelFunc + if c.options.Timeout > 0 { + // Give the whole enumeration a generous per-host ceiling: + // timeout * (number of ciphers / concurrency + 1) keeps things moving + // without cutting off legitimately slow hosts. + perHostDeadline := time.Duration(c.options.Timeout) * time.Second * + time.Duration((len(toEnumerate)/threads)+1) + enumCtx, enumCancel = context.WithTimeout(context.Background(), perHostDeadline) + } else { + enumCtx, enumCancel = context.WithCancel(context.Background()) + } + defer enumCancel() + // setup connection pool - pool, err := connpool.NewOneTimePool(context.Background(), address, threads) + pool, err := connpool.NewOneTimePool(enumCtx, address, threads) if err != nil { return enumeratedCiphers, errorutil.NewWithErr(err).Msgf("failed to setup connection pool") //nolint } @@ -227,8 +245,14 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con for _, v := range toEnumerate { // create new baseConn and pass it to tlsclient - baseConn, err := pool.Acquire(context.Background()) + // Use enumCtx (with timeout) instead of context.Background() so + // Acquire unblocks when the overall enumeration deadline expires. + baseConn, err := pool.Acquire(enumCtx) if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + // Timeout hit: return what we have so far instead of hanging. + return enumeratedCiphers, nil + } return enumeratedCiphers, errorutil.NewWithErr(err).WithTag("ctls") //nolint } stats.IncrementCryptoTLSConnections() @@ -236,7 +260,9 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) - if err := conn.Handshake(); err == nil { + // Use HandshakeContext so the enumeration deadline/cancellation is + // respected during the handshake itself, not just during pool.Acquire. + if err := conn.HandshakeContext(enumCtx); err == nil { ciphersuite := conn.ConnectionState().CipherSuite enumeratedCiphers = append(enumeratedCiphers, tls.CipherSuiteName(ciphersuite)) } diff --git a/pkg/tlsx/ztls/ztls.go b/pkg/tlsx/ztls/ztls.go index a03b7267..2e43c73e 100644 --- a/pkg/tlsx/ztls/ztls.go +++ b/pkg/tlsx/ztls/ztls.go @@ -226,8 +226,23 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con threads = len(toEnumerate) } + // Build a context that respects the global timeout so cipher enumeration + // cannot block forever if a host stops responding mid-scan. + // context.Background() / context.TODO() were previously used here, causing + // pool.Acquire and tlsHandshakeWithTimeout to hang indefinitely (issue #819). + enumCtx := context.Background() + var enumCancel context.CancelFunc + if c.options.Timeout > 0 { + perHostDeadline := time.Duration(c.options.Timeout) * time.Second * + time.Duration((len(toEnumerate)/threads)+1) + enumCtx, enumCancel = context.WithTimeout(context.Background(), perHostDeadline) + } else { + enumCtx, enumCancel = context.WithCancel(context.Background()) + } + defer enumCancel() + // setup connection pool - pool, err := connpool.NewOneTimePool(context.Background(), address, threads) + pool, err := connpool.NewOneTimePool(enumCtx, address, threads) if err != nil { return enumeratedCiphers, errorutil.NewWithErr(err).Msgf("failed to setup connection pool") //nolint } @@ -249,15 +264,20 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con gologger.Debug().Label("ztls").Msgf("Starting cipher enumeration with %v ciphers in %v", len(toEnumerate), options.VersionTLS) for _, v := range toEnumerate { - baseConn, err := pool.Acquire(context.Background()) + // Use enumCtx so Acquire unblocks when the enumeration deadline expires. + baseConn, err := pool.Acquire(enumCtx) if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return enumeratedCiphers, nil + } return enumeratedCiphers, errorutil.NewWithErr(err).WithTag("ztls") //nolint } stats.IncrementZcryptoTLSConnections() conn := tls.Client(baseConn, baseCfg) baseCfg.CipherSuites = []uint16{ztlsCiphers[v]} - if err := c.tlsHandshakeWithTimeout(conn, context.TODO()); err == nil { + // Use enumCtx instead of context.TODO() to propagate the timeout. + if err := c.tlsHandshakeWithTimeout(conn, enumCtx); err == nil { h1 := conn.GetHandshakeLog() enumeratedCiphers = append(enumeratedCiphers, h1.ServerHello.CipherSuite.String()) } @@ -320,20 +340,32 @@ func (c *Client) getConfig(hostname, ip, port string, options clients.ConnectOpt return config, nil } -// tlsHandshakeWithCtx attempts tls handshake with given timeout +// tlsHandshakeWithTimeout attempts a TLS handshake and returns when it +// completes or when ctx is cancelled/expired, whichever comes first. +// +// The previous implementation had a subtle bug: the select case +// +// case errChan <- tlsConn.Handshake(): +// +// evaluates Handshake() synchronously before the select statement is +// entered. If the handshake blocks indefinitely (e.g. the server stops +// responding mid-handshake) the ctx.Done() case is never evaluated and the +// function hangs forever. The fix runs the handshake in a goroutine so +// that both channels can be observed concurrently. func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, ctx context.Context) error { errChan := make(chan error, 1) - defer close(errChan) + + go func() { + errChan <- tlsConn.Handshake() + }() select { case <-ctx.Done(): return errorutil.NewWithTag("ztls", "timeout while attempting handshake") //nolint - case errChan <- tlsConn.Handshake(): - } - - err := <-errChan - if err == tls.ErrCertsOnly { - err = nil + case err := <-errChan: + if err == tls.ErrCertsOnly { + return nil + } + return err } - return err }