Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions pkg/tlsx/tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,26 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con
threads = len(toEnumerate)
}

// Build a context that respects the global timeout so cipher enumeration
// cannot block forever if a host stops responding mid-scan.
// context.Background() was previously used here, causing pool.Acquire to
// hang indefinitely when all pool connections were exhausted (issue #819).
enumCtx := context.Background()
var enumCancel context.CancelFunc
if c.options.Timeout > 0 {
// Give the whole enumeration a generous per-host ceiling:
// timeout * (number of ciphers / concurrency + 1) keeps things moving
// without cutting off legitimately slow hosts.
perHostDeadline := time.Duration(c.options.Timeout) * time.Second *
time.Duration((len(toEnumerate)/threads)+1)
enumCtx, enumCancel = context.WithTimeout(context.Background(), perHostDeadline)
} else {
enumCtx, enumCancel = context.WithCancel(context.Background())
}
defer enumCancel()

// setup connection pool
pool, err := connpool.NewOneTimePool(context.Background(), address, threads)
pool, err := connpool.NewOneTimePool(enumCtx, address, threads)
if err != nil {
return enumeratedCiphers, errorutil.NewWithErr(err).Msgf("failed to setup connection pool") //nolint
}
Expand All @@ -227,16 +245,24 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con

for _, v := range toEnumerate {
// create new baseConn and pass it to tlsclient
baseConn, err := pool.Acquire(context.Background())
// Use enumCtx (with timeout) instead of context.Background() so
// Acquire unblocks when the overall enumeration deadline expires.
baseConn, err := pool.Acquire(enumCtx)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
// Timeout hit: return what we have so far instead of hanging.
return enumeratedCiphers, nil
}
return enumeratedCiphers, errorutil.NewWithErr(err).WithTag("ctls") //nolint
}
stats.IncrementCryptoTLSConnections()
baseCfg.CipherSuites = []uint16{tlsCiphers[v]}

conn := tls.Client(baseConn, baseCfg)

if err := conn.Handshake(); err == nil {
// Use HandshakeContext so the enumeration deadline/cancellation is
// respected during the handshake itself, not just during pool.Acquire.
if err := conn.HandshakeContext(enumCtx); err == nil {
ciphersuite := conn.ConnectionState().CipherSuite
enumeratedCiphers = append(enumeratedCiphers, tls.CipherSuiteName(ciphersuite))
}
Expand Down
56 changes: 44 additions & 12 deletions pkg/tlsx/ztls/ztls.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,23 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con
threads = len(toEnumerate)
}

// Build a context that respects the global timeout so cipher enumeration
// cannot block forever if a host stops responding mid-scan.
// context.Background() / context.TODO() were previously used here, causing
// pool.Acquire and tlsHandshakeWithTimeout to hang indefinitely (issue #819).
enumCtx := context.Background()
var enumCancel context.CancelFunc
if c.options.Timeout > 0 {
perHostDeadline := time.Duration(c.options.Timeout) * time.Second *
time.Duration((len(toEnumerate)/threads)+1)
enumCtx, enumCancel = context.WithTimeout(context.Background(), perHostDeadline)
} else {
enumCtx, enumCancel = context.WithCancel(context.Background())
}
defer enumCancel()

// setup connection pool
pool, err := connpool.NewOneTimePool(context.Background(), address, threads)
pool, err := connpool.NewOneTimePool(enumCtx, address, threads)
if err != nil {
return enumeratedCiphers, errorutil.NewWithErr(err).Msgf("failed to setup connection pool") //nolint
}
Expand All @@ -249,15 +264,20 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con
gologger.Debug().Label("ztls").Msgf("Starting cipher enumeration with %v ciphers in %v", len(toEnumerate), options.VersionTLS)

for _, v := range toEnumerate {
baseConn, err := pool.Acquire(context.Background())
// Use enumCtx so Acquire unblocks when the enumeration deadline expires.
baseConn, err := pool.Acquire(enumCtx)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return enumeratedCiphers, nil
}
return enumeratedCiphers, errorutil.NewWithErr(err).WithTag("ztls") //nolint
}
stats.IncrementZcryptoTLSConnections()
conn := tls.Client(baseConn, baseCfg)
baseCfg.CipherSuites = []uint16{ztlsCiphers[v]}

if err := c.tlsHandshakeWithTimeout(conn, context.TODO()); err == nil {
// Use enumCtx instead of context.TODO() to propagate the timeout.
if err := c.tlsHandshakeWithTimeout(conn, enumCtx); err == nil {
h1 := conn.GetHandshakeLog()
enumeratedCiphers = append(enumeratedCiphers, h1.ServerHello.CipherSuite.String())
}
Expand Down Expand Up @@ -320,20 +340,32 @@ func (c *Client) getConfig(hostname, ip, port string, options clients.ConnectOpt
return config, nil
}

// tlsHandshakeWithCtx attempts tls handshake with given timeout
// tlsHandshakeWithTimeout attempts a TLS handshake and returns when it
// completes or when ctx is cancelled/expired, whichever comes first.
//
// The previous implementation had a subtle bug: the select case
//
// case errChan <- tlsConn.Handshake():
//
// evaluates Handshake() synchronously before the select statement is
// entered. If the handshake blocks indefinitely (e.g. the server stops
// responding mid-handshake) the ctx.Done() case is never evaluated and the
// function hangs forever. The fix runs the handshake in a goroutine so
// that both channels can be observed concurrently.
func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, ctx context.Context) error {
errChan := make(chan error, 1)
defer close(errChan)

go func() {
errChan <- tlsConn.Handshake()
}()

select {
case <-ctx.Done():
return errorutil.NewWithTag("ztls", "timeout while attempting handshake") //nolint
case errChan <- tlsConn.Handshake():
}

err := <-errChan
if err == tls.ErrCertsOnly {
err = nil
case err := <-errChan:
if err == tls.ErrCertsOnly {
return nil
}
return err
}
return err
}