Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cmd/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ func NewConfiguration(args []string, logger *log.Logger) (*smokescreen.Config, e
Value: smokescreen.DefaultMaxRequestBurst,
Usage: "Maximum burst capacity for rate limiting.\n\t\tMust be greater than max-request-rate when specified.\n\t\tOmit to use default (2x max-request-rate).",
},
cli.IntFlag{
Name: "max-concurrent-connect-tunnels",
Value: smokescreen.DefaultMaxConcurrentConnectTunnels,
Usage: "Maximum number of concurrent CONNECT tunnels.\n\t\tUnlike max-concurrent-requests, this limits actual long-lived connections.\n\t\t0 = unlimited (default).",
},
cli.DurationFlag{
Name: "dns-timeout",
Value: smokescreen.DefaultDNSTimeout,
Expand Down Expand Up @@ -333,6 +338,11 @@ func NewConfiguration(args []string, logger *log.Logger) (*smokescreen.Config, e
}
}

if c.IsSet("max-concurrent-connect-tunnels") {
maxTunnels := c.Int("max-concurrent-connect-tunnels")
conf.MaxConcurrentConnectTunnels = maxTunnels
}

if c.IsSet("dns-timeout") {
conf.DNSTimeout = c.Duration("dns-timeout")
}
Expand Down
40 changes: 29 additions & 11 deletions pkg/smokescreen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import (
// Configuration defaults
const (
// Server defaults
DefaultPort uint16 = 4750
DefaultConnectTimeout = 10 * time.Second
DefaultExitTimeout = 500 * time.Minute
DefaultNetwork = "ip"
DefaultStatsSocketFileMode = 0700
DefaultPort uint16 = 4750
DefaultConnectTimeout = 10 * time.Second
DefaultExitTimeout = 500 * time.Minute
DefaultNetwork = "ip"
DefaultStatsSocketFileMode = 0700

// HTTP server timeouts
DefaultReadHeaderTimeout = 300 * time.Second
Expand All @@ -52,9 +52,10 @@ const (
DefaultStatsdAddress = "127.0.0.1:8200"

// Rate limiting defaults
DefaultMaxConcurrentRequests = 0 // 0 = unlimited
DefaultMaxRequestRate = 0.0 // 0 = unlimited
DefaultMaxRequestBurst = -1 // -1 = use 2x rate
DefaultMaxConcurrentRequests = 0 // 0 = unlimited
DefaultMaxRequestRate = 0.0 // 0 = unlimited
DefaultMaxRequestBurst = -1 // -1 = use 2x rate
DefaultMaxConcurrentConnectTunnels = 0 // 0 = unlimited
)

type RuleRange struct {
Expand Down Expand Up @@ -178,6 +179,15 @@ type Config struct {
// Set to 0 to use default (2x MaxRequestRate).
MaxRequestBurst int

// MaxConcurrentConnectTunnels limits the number of active CONNECT tunnels.
// Unlike MaxConcurrentRequests which only limits request processing time,
// this limits the actual number of long-lived tunnel connections.
// Set to 0 to disable tunnel limiting.
MaxConcurrentConnectTunnels int

// TunnelLimiter is used internally to enforce MaxConcurrentConnectTunnels.
TunnelLimiter *TunnelLimiter

// DNSTimeout is the maximum time to wait for DNS resolution.
// Set to 0 to use default (5 seconds).
DNSTimeout time.Duration
Expand Down Expand Up @@ -343,22 +353,30 @@ func (config *Config) SetRateLimits(maxConcurrent int, maxRate float64, maxReque
if maxRate < 0 {
return fmt.Errorf("maxRate must be >= 0, got %.2f", maxRate)
}

if maxRequestBurst >= 0 && maxRate > 0 && float64(maxRequestBurst) <= maxRate {
return fmt.Errorf("maxRequestBurst (%d) must be greater than maxRequestRate (%.2f); omit to use default (2x rate)", maxRequestBurst, maxRate)
}

// Apply default: 2x rate when not explicitly configured or configured negative
if maxRequestBurst < 0 {
maxRequestBurst = int(maxRate * 2)
}

config.MaxConcurrentRequests = maxConcurrent
config.MaxRequestRate = maxRate
config.MaxRequestBurst = maxRequestBurst
return nil
}

// initializeTunnelLimiter creates the TunnelLimiter if MaxConcurrentConnectTunnels
// is set but the limiter hasn't been initialized yet.
func (config *Config) initializeTunnelLimiter() {
if config.MaxConcurrentConnectTunnels > 0 && config.TunnelLimiter == nil {
config.TunnelLimiter = NewTunnelLimiter(config.MaxConcurrentConnectTunnels, config)
}
}

// RFC 5280, 4.2.1.1
type authKeyId struct {
Id []byte `asn1:"optional,tag:0"`
Expand Down
10 changes: 9 additions & 1 deletion pkg/smokescreen/config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ type yamlConfig struct {
MaxRequestRate float64 `yaml:"max_request_rate"`
MaxRequestBurst *int `yaml:"max_request_burst"`

DNSTimeout time.Duration `yaml:"dns_timeout"`
// Tunnel limiting (for long-lived CONNECT connections)
MaxConcurrentConnectTunnels int `yaml:"max_concurrent_connect_tunnels"`

DNSTimeout time.Duration `yaml:"dns_timeout"`
}

func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
Expand Down Expand Up @@ -226,6 +229,11 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
}
}

// Set tunnel limit for CONNECT connections
if yc.MaxConcurrentConnectTunnels > 0 {
c.MaxConcurrentConnectTunnels = yc.MaxConcurrentConnectTunnels
}

if yc.DNSTimeout > 0 {
c.DNSTimeout = yc.DNSTimeout
}
Expand Down
9 changes: 9 additions & 0 deletions pkg/smokescreen/conntrack/instrumented_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ type InstrumentedConn struct {

closed bool
CloseError error

// OnClose is called when the connection is closed.
// This can be used to release resources like tunnel limiter slots.
OnClose func()
}

func (t *Tracker) NewInstrumentedConnWithTimeout(conn net.Conn, timeout time.Duration, logger *logrus.Entry, role, outboundHost, proxyType, project string) *InstrumentedConn {
Expand Down Expand Up @@ -91,6 +95,11 @@ func (ic *InstrumentedConn) Close() error {
ic.closed = true
ic.tracker.Delete(ic)

// Call OnClose callback if set (e.g., to release tunnel limiter slot)
if ic.OnClose != nil {
ic.OnClose()
}

end := time.Now()
duration := end.Sub(ic.Start).Seconds()

Expand Down
47 changes: 42 additions & 5 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ type SmokescreenContext struct {
// This is an explicit flag that ensures role reuse only for CONNECT MITM requests
// and prevents attacks that try to reuse the role for traditional HTTP proxy requests.
isConnectMitm bool
// tunnelSlotAcquired indicates whether a tunnel limiter slot was acquired for this connection.
// If true, the slot must be released when the connection closes.
tunnelSlotAcquired bool
}

// ExitStatus is used to log Smokescreen's connection status at shutdown time
Expand Down Expand Up @@ -495,6 +498,15 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
// cost incurred by Smokescreen.
sctx.cfg.MetricsClient.Timing("proxy_duration_ms", time.Since(sctx.start), 1)

// For CONNECT tunnels, acquire a tunnel limiter slot before dialing.
// Unlike the rate limiter, this tracks actual long-lived tunnel connections.
if sctx.ProxyType == connectProxy && sctx.cfg.TunnelLimiter != nil {
if !sctx.cfg.TunnelLimiter.Acquire() {
return nil, denyError{ErrTunnelLimitExceeded}
}
sctx.tunnelSlotAcquired = true
}

var conn net.Conn
var err error

Expand All @@ -515,6 +527,13 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "false"}, 1)
sctx.cfg.ConnTracker.RecordAttempt(sctx.RequestedHost, false)
metrics.ReportConnError(sctx.cfg.MetricsClient, err)

// Release tunnel slot if acquired since connection failed
if sctx.tunnelSlotAcquired && sctx.cfg.TunnelLimiter != nil {
sctx.cfg.TunnelLimiter.Release()
sctx.tunnelSlotAcquired = false
}

return nil, err
}
sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "true"}, 1)
Expand All @@ -525,6 +544,16 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if sctx.ProxyType == connectProxy {
ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.Logger, d.Role, d.OutboundHost, sctx.ProxyType, d.Project)
pctx.ConnErrorHandler = ic.Error

// Set up tunnel limiter release callback.
// The slot was acquired above right before dialing, and must be released when the connection closes.
if sctx.tunnelSlotAcquired && sctx.cfg.TunnelLimiter != nil {
ic.OnClose = func() {
sctx.cfg.TunnelLimiter.Release()
sctx.tunnelSlotAcquired = false
}
}

conn = ic
} else {
conn = NewTimeoutConn(conn, sctx.cfg.IdleTimeout)
Expand Down Expand Up @@ -719,7 +748,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
existingSctx, ok := pctx.UserData.(*SmokescreenContext)
if !ok || existingSctx.Decision == nil {
config.Log.WithFields(logrus.Fields{
"context_valid": ok,
"context_valid": ok,
"decision_present": existingSctx != nil && existingSctx.Decision != nil,
}).Error("MITM request missing required context or decision from CONNECT phase - rejecting request")
err := errors.New("MITM request missing context from CONNECT phase")
Expand Down Expand Up @@ -1021,6 +1050,14 @@ func StartWithConfig(config *Config, quit <-chan interface{}) {
config.MaxConcurrentRequests, config.MaxRequestRate, config.MaxRequestBurst)
}

if config.MaxConcurrentConnectTunnels > 0 {
config.initializeTunnelLimiter()
if config.TunnelLimiter != nil {
config.Log.Printf("CONNECT tunnel limiting enabled (max_concurrent_tunnels=%d)",
config.MaxConcurrentConnectTunnels)
}
}

if config.Healthcheck != nil {
handler = &HealthcheckMiddleware{
Proxy: handler,
Expand Down Expand Up @@ -1274,21 +1311,21 @@ func checkACLsForRequest(config *Config, sctx *SmokescreenContext, req *http.Req

var role string
var roleErr error

// Check if role is already populated in SmokescreenContext (e.g., from CONNECT in MITM mode)
if sctx.isConnectMitm {
if sctx.Decision == nil || sctx.Decision.Role == "" {
config.Log.WithFields(logrus.Fields{
"decision_nil": sctx.Decision == nil,
"role_empty": sctx.Decision != nil && sctx.Decision.Role == "",
"decision_nil": sctx.Decision == nil,
"role_empty": sctx.Decision != nil && sctx.Decision.Role == "",
}).Error("MITM request missing required role from CONNECT phase")
config.MetricsClient.Incr("acl.role_not_determined", 1)
decision.Reason = "Client role cannot be determined"
return decision
}
role = sctx.Decision.Role
config.Log.WithFields(logrus.Fields{
"role": role,
"role": role,
"destination": destination.String(),
}).Info("Reusing existing role from context (MITM)")
} else {
Expand Down
Loading