diff --git a/cmd/smokescreen.go b/cmd/smokescreen.go
index ce979345..1bb6ddb8 100644
--- a/cmd/smokescreen.go
+++ b/cmd/smokescreen.go
@@ -170,6 +170,11 @@ func NewConfiguration(args []string, logger *log.Logger) (*smokescreen.Config, e
 			Value: smokescreen.DefaultMaxRequestBurst,
 			Usage: "Maximum burst capacity for rate limiting.\n\t\tMust be greater than max-request-rate when specified.\n\t\tOmit to use default (2x max-request-rate).",
 		},
+		cli.IntFlag{
+			Name:  "max-concurrent-connect-tunnels",
+			Value: smokescreen.DefaultMaxConcurrentConnectTunnels,
+			Usage: "Maximum number of concurrent CONNECT tunnels.\n\t\tUnlike max-concurrent-requests, this limits actual long-lived connections.\n\t\t0 = unlimited (default).",
+		},
 		cli.DurationFlag{
 			Name:  "dns-timeout",
 			Value: smokescreen.DefaultDNSTimeout,
@@ -333,6 +338,11 @@ func NewConfiguration(args []string, logger *log.Logger) (*smokescreen.Config, e
 		}
 	}
 
+	if c.IsSet("max-concurrent-connect-tunnels") {
+		maxTunnels := c.Int("max-concurrent-connect-tunnels")
+		conf.MaxConcurrentConnectTunnels = maxTunnels
+	}
+
 	if c.IsSet("dns-timeout") {
 		conf.DNSTimeout = c.Duration("dns-timeout")
 	}
diff --git a/pkg/smokescreen/config.go b/pkg/smokescreen/config.go
index 68388e26..1c5a8103 100644
--- a/pkg/smokescreen/config.go
+++ b/pkg/smokescreen/config.go
@@ -29,11 +29,11 @@ import (
 // Configuration defaults
 const (
 	// Server defaults
-	DefaultPort uint16 = 4750
-	DefaultConnectTimeout = 10 * time.Second
-	DefaultExitTimeout = 500 * time.Minute
-	DefaultNetwork = "ip"
-	DefaultStatsSocketFileMode = 0700
+	DefaultPort                uint16 = 4750
+	DefaultConnectTimeout             = 10 * time.Second
+	DefaultExitTimeout                = 500 * time.Minute
+	DefaultNetwork                    = "ip"
+	DefaultStatsSocketFileMode        = 0700
 
 	// HTTP server timeouts
 	DefaultReadHeaderTimeout = 300 * time.Second
@@ -52,9 +52,10 @@ const (
 	DefaultStatsdAddress = "127.0.0.1:8200"
 
 	// Rate limiting defaults
-	DefaultMaxConcurrentRequests = 0   // 0 = unlimited
-	DefaultMaxRequestRate        = 0.0 // 0 = unlimited
-	DefaultMaxRequestBurst       = -1  // -1 = use 2x rate
+	DefaultMaxConcurrentRequests       = 0   // 0 = unlimited
+	DefaultMaxRequestRate              = 0.0 // 0 = unlimited
+	DefaultMaxRequestBurst             = -1  // -1 = use 2x rate
+	DefaultMaxConcurrentConnectTunnels = 0   // 0 = unlimited
 )
 
 type RuleRange struct {
@@ -178,6 +179,15 @@ type Config struct {
 	// Set to 0 to use default (2x MaxRequestRate).
 	MaxRequestBurst int
 
+	// MaxConcurrentConnectTunnels limits the number of active CONNECT tunnels.
+	// Unlike MaxConcurrentRequests, which only counts a request while it is
+	// being processed, this caps the number of long-lived tunnel connections.
+	// Set to 0 to disable tunnel limiting.
+	MaxConcurrentConnectTunnels int
+
+	// TunnelLimiter is used internally to enforce MaxConcurrentConnectTunnels.
+	TunnelLimiter *TunnelLimiter
+
 	// DNSTimeout is the maximum time to wait for DNS resolution.
 	// Set to 0 to use default (5 seconds).
 	DNSTimeout time.Duration
@@ -343,22 +353,30 @@ func (config *Config) SetRateLimits(maxConcurrent int, maxRate float64, maxReque
 	if maxRate < 0 {
 		return fmt.Errorf("maxRate must be >= 0, got %.2f", maxRate)
 	}
-	
+
 	if maxRequestBurst >= 0 && maxRate > 0 && float64(maxRequestBurst) <= maxRate {
 		return fmt.Errorf("maxRequestBurst (%d) must be greater than maxRequestRate (%.2f); omit to use default (2x rate)", maxRequestBurst, maxRate)
 	}
-	
+
 	// Apply default: 2x rate when not explicitly configured or configured negative
 	if maxRequestBurst < 0 {
 		maxRequestBurst = int(maxRate * 2)
 	}
-	
+
 	config.MaxConcurrentRequests = maxConcurrent
 	config.MaxRequestRate = maxRate
 	config.MaxRequestBurst = maxRequestBurst
 	return nil
 }
 
+// initializeTunnelLimiter creates the TunnelLimiter if MaxConcurrentConnectTunnels
+// is set but the limiter hasn't been initialized yet.
+func (config *Config) initializeTunnelLimiter() {
+	if config.MaxConcurrentConnectTunnels > 0 && config.TunnelLimiter == nil {
+		config.TunnelLimiter = NewTunnelLimiter(config.MaxConcurrentConnectTunnels, config)
+	}
+}
+
 // RFC 5280, 4.2.1.1
 type authKeyId struct {
 	Id []byte `asn1:"optional,tag:0"`
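
The flag above, the YAML key below, and direct assignment to Config all feed MaxConcurrentConnectTunnels; the limiter itself is only constructed inside StartWithConfig via initializeTunnelLimiter. A minimal embedding sketch, assuming the package's existing NewConfig constructor (error handling elided):

    package main

    import "github.com/stripe/smokescreen/pkg/smokescreen"

    func main() {
    	conf := smokescreen.NewConfig()        // assumed constructor from the existing package API
    	conf.MaxConcurrentConnectTunnels = 100 // cap long-lived CONNECT tunnels at 100

    	// StartWithConfig (patched below) calls initializeTunnelLimiter when the limit is > 0.
    	quit := make(chan interface{})
    	smokescreen.StartWithConfig(conf, quit)
    }
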
diff --git a/pkg/smokescreen/config_loader.go b/pkg/smokescreen/config_loader.go
index 07c9077c..b8576f51 100644
--- a/pkg/smokescreen/config_loader.go
+++ b/pkg/smokescreen/config_loader.go
@@ -67,7 +67,10 @@ type yamlConfig struct {
 	MaxRequestRate  float64 `yaml:"max_request_rate"`
 	MaxRequestBurst *int    `yaml:"max_request_burst"`
 
-	DNSTimeout time.Duration `yaml:"dns_timeout"`
+	// Tunnel limiting (for long-lived CONNECT connections)
+	MaxConcurrentConnectTunnels int `yaml:"max_concurrent_connect_tunnels"`
+
+	DNSTimeout time.Duration `yaml:"dns_timeout"`
 }
 
 func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -226,6 +229,11 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
 		}
 	}
 
+	// Set the tunnel limit for CONNECT connections
+	if yc.MaxConcurrentConnectTunnels > 0 {
+		c.MaxConcurrentConnectTunnels = yc.MaxConcurrentConnectTunnels
+	}
+
 	if yc.DNSTimeout > 0 {
 		c.DNSTimeout = yc.DNSTimeout
 	}
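
A sketch of the new key in YAML form, decoded through the patched UnmarshalYAML via gopkg.in/yaml.v2 (the unmarshal-function signature above is the yaml.v2 style). Illustrative only; a real config also needs whatever other keys the loader requires:

    package main

    import (
    	"fmt"
    	"log"

    	"github.com/stripe/smokescreen/pkg/smokescreen"
    	yaml "gopkg.in/yaml.v2"
    )

    const snippet = "max_concurrent_connect_tunnels: 100\n"

    func main() {
    	var conf smokescreen.Config
    	// The patched UnmarshalYAML copies the value only when it is > 0.
    	if err := yaml.Unmarshal([]byte(snippet), &conf); err != nil {
    		log.Fatal(err)
    	}
    	fmt.Println(conf.MaxConcurrentConnectTunnels) // 100
    }
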
diff --git a/pkg/smokescreen/conntrack/instrumented_conn.go b/pkg/smokescreen/conntrack/instrumented_conn.go
index a8ca0783..5cb6a482 100644
--- a/pkg/smokescreen/conntrack/instrumented_conn.go
+++ b/pkg/smokescreen/conntrack/instrumented_conn.go
@@ -43,6 +43,10 @@ type InstrumentedConn struct {
 
 	closed     bool
 	CloseError error
+
+	// OnClose is called when the connection is closed.
+	// It can be used to release resources such as tunnel limiter slots.
+	OnClose func()
 }
 
 func (t *Tracker) NewInstrumentedConnWithTimeout(conn net.Conn, timeout time.Duration, logger *logrus.Entry, role, outboundHost, proxyType, project string) *InstrumentedConn {
@@ -91,6 +95,11 @@ func (ic *InstrumentedConn) Close() error {
 	ic.closed = true
 	ic.tracker.Delete(ic)
 
+	// Call the OnClose callback if set (e.g., to release a tunnel limiter slot)
+	if ic.OnClose != nil {
+		ic.OnClose()
+	}
+
 	end := time.Now()
 	duration := end.Sub(ic.Start).Seconds()
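
The OnClose hook is intentionally generic: Close() already guards on ic.closed, so the callback fires at most once per connection, which makes it a safe anchor for pairing a release with an earlier acquisition. A sketch of the pattern with a hypothetical helper (not part of this patch; the role/host/project arguments are placeholders):

    package example

    import (
    	"net"
    	"time"

    	"github.com/sirupsen/logrus"

    	"github.com/stripe/smokescreen/pkg/smokescreen/conntrack"
    )

    // wrapWithRelease ties a one-shot release action to the connection's
    // guarded Close path.
    func wrapWithRelease(t *conntrack.Tracker, conn net.Conn, logger *logrus.Entry, release func()) net.Conn {
    	ic := t.NewInstrumentedConnWithTimeout(conn, 10*time.Minute, logger,
    		"some-role", "example.com:443", "connect", "some-project")
    	ic.OnClose = release // runs at most once; Close() checks ic.closed first
    	return ic
    }
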
diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go
index 73a59f24..e34f945e 100644
--- a/pkg/smokescreen/smokescreen.go
+++ b/pkg/smokescreen/smokescreen.go
@@ -92,6 +92,9 @@ type SmokescreenContext struct {
 	// This is an explicit flag that ensures role reuse only for CONNECT MITM requests
 	// and prevents attacks that try to reuse the role for traditional HTTP proxy requests.
 	isConnectMitm bool
+	// tunnelSlotAcquired indicates whether a tunnel limiter slot was acquired for this connection.
+	// If true, the slot must be released when the connection closes.
+	tunnelSlotAcquired bool
 }
 
 // ExitStatus is used to log Smokescreen's connection status at shutdown time
@@ -495,6 +498,15 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
 	// cost incurred by Smokescreen.
 	sctx.cfg.MetricsClient.Timing("proxy_duration_ms", time.Since(sctx.start), 1)
 
+	// For CONNECT tunnels, acquire a tunnel limiter slot before dialing.
+	// Unlike the rate limiter, this tracks actual long-lived tunnel connections.
+	if sctx.ProxyType == connectProxy && sctx.cfg.TunnelLimiter != nil {
+		if !sctx.cfg.TunnelLimiter.Acquire() {
+			return nil, denyError{ErrTunnelLimitExceeded}
+		}
+		sctx.tunnelSlotAcquired = true
+	}
+
 	var conn net.Conn
 	var err error
 
@@ -515,6 +527,13 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
 		sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "false"}, 1)
 		sctx.cfg.ConnTracker.RecordAttempt(sctx.RequestedHost, false)
 		metrics.ReportConnError(sctx.cfg.MetricsClient, err)
+
+		// Release the tunnel slot if one was acquired, since the dial failed
+		if sctx.tunnelSlotAcquired && sctx.cfg.TunnelLimiter != nil {
+			sctx.cfg.TunnelLimiter.Release()
+			sctx.tunnelSlotAcquired = false
+		}
+
 		return nil, err
 	}
 	sctx.cfg.MetricsClient.IncrWithTags("cn.atpt.total", map[string]string{"success": "true"}, 1)
@@ -525,6 +544,16 @@ func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
 	if sctx.ProxyType == connectProxy {
 		ic := sctx.cfg.ConnTracker.NewInstrumentedConnWithTimeout(conn, sctx.cfg.IdleTimeout, sctx.Logger, d.Role, d.OutboundHost, sctx.ProxyType, d.Project)
 		pctx.ConnErrorHandler = ic.Error
+
+		// Set up the tunnel limiter release callback.
+		// The slot was acquired just before dialing and must be released when the connection closes.
+		if sctx.tunnelSlotAcquired && sctx.cfg.TunnelLimiter != nil {
+			ic.OnClose = func() {
+				sctx.cfg.TunnelLimiter.Release()
+				sctx.tunnelSlotAcquired = false
+			}
+		}
+
 		conn = ic
 	} else {
 		conn = NewTimeoutConn(conn, sctx.cfg.IdleTimeout)
@@ -719,7 +748,7 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
 			existingSctx, ok := pctx.UserData.(*SmokescreenContext)
 			if !ok || existingSctx.Decision == nil {
 				config.Log.WithFields(logrus.Fields{
-					"context_valid": ok,
+					"context_valid":    ok,
 					"decision_present": existingSctx != nil && existingSctx.Decision != nil,
 				}).Error("MITM request missing required context or decision from CONNECT phase - rejecting request")
 				err := errors.New("MITM request missing context from CONNECT phase")
@@ -1021,6 +1050,14 @@ func StartWithConfig(config *Config, quit <-chan interface{}) {
 			config.MaxConcurrentRequests, config.MaxRequestRate, config.MaxRequestBurst)
 	}
 
+	if config.MaxConcurrentConnectTunnels > 0 {
+		config.initializeTunnelLimiter()
+		if config.TunnelLimiter != nil {
+			config.Log.Printf("CONNECT tunnel limiting enabled (max_concurrent_tunnels=%d)",
+				config.MaxConcurrentConnectTunnels)
+		}
+	}
+
 	if config.Healthcheck != nil {
 		handler = &HealthcheckMiddleware{
 			Proxy: handler,
@@ -1274,13 +1311,13 @@ func checkACLsForRequest(config *Config, sctx *SmokescreenContext, req *http.Req
 
 	var role string
 	var roleErr error
-	
+
 	// Check if role is already populated in SmokescreenContext (e.g., from CONNECT in MITM mode)
 	if sctx.isConnectMitm {
 		if sctx.Decision == nil || sctx.Decision.Role == "" {
 			config.Log.WithFields(logrus.Fields{
-				"decision_nil": sctx.Decision == nil,
-				"role_empty": sctx.Decision != nil && sctx.Decision.Role == "",
+				"decision_nil": sctx.Decision == nil,
+				"role_empty":   sctx.Decision != nil && sctx.Decision.Role == "",
 			}).Error("MITM request missing required role from CONNECT phase")
 			config.MetricsClient.Incr("acl.role_not_determined", 1)
 			decision.Reason = "Client role cannot be determined"
@@ -1288,7 +1325,7 @@ func checkACLsForRequest(config *Config, sctx *SmokescreenContext, req *http.Req
 		}
 		role = sctx.Decision.Role
 		config.Log.WithFields(logrus.Fields{
-			"role": role,
+			"role":        role,
 			"destination": destination.String(),
 		}).Info("Reusing existing role from context (MITM)")
 	} else {
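
The slot lifecycle in dialContext has three outcomes: denied up front at capacity, released immediately when the dial fails, or held until the instrumented connection's Close fires OnClose. A compact restatement using only the exported limiter API (the dial function here is a stand-in, not the proxy's dialer):

    package main

    import (
    	"errors"
    	"fmt"

    	"github.com/stripe/smokescreen/pkg/smokescreen"
    )

    // dial stands in for the real dialer; flip the returned error to trace both branches.
    func dial() error { return errors.New("dial failed") }

    func main() {
    	tl := smokescreen.NewTunnelLimiter(1, nil)

    	if !tl.Acquire() {
    		fmt.Println("at capacity: the CONNECT is denied with ErrTunnelLimitExceeded")
    		return
    	}
    	if err := dial(); err != nil {
    		tl.Release() // mirrors the failed-dial branch above
    	}
    	// On success the slot would instead be released by InstrumentedConn's OnClose.
    	fmt.Println(tl.ActiveCount()) // 0: the failed dial returned its slot
    }
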
diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go
index 2ac82e4e..418574e1 100644
--- a/pkg/smokescreen/smokescreen_test.go
+++ b/pkg/smokescreen/smokescreen_test.go
@@ -17,6 +17,7 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -2096,3 +2097,158 @@ func TestRoleLoggingInCanonicalProxyDecision(t *testing.T) {
 		r.Equal("connect", proxyDecision.Data["proxy_type"])
 	})
 }
+
+func TestMaxConcurrentConnectTunnels(t *testing.T) {
+	r := require.New(t)
+
+	// Create a slow backend that holds connections open
+	slowBackend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+		// Hold the connection open for a bit
+		time.Sleep(500 * time.Millisecond)
+		w.WriteHeader(200)
+		w.Write([]byte("OK"))
+	}))
+	defer slowBackend.Close()
+
+	t.Run("tunnel limit enforced", func(t *testing.T) {
+		cfg, err := testConfig("test-local-srv")
+		r.NoError(err)
+		err = cfg.SetAllowAddresses([]string{"127.0.0.1"})
+		r.NoError(err)
+
+		// Set the tunnel limit to 2
+		cfg.MaxConcurrentConnectTunnels = 2
+		cfg.TunnelLimiter = NewTunnelLimiter(2, cfg)
+
+		l, err := net.Listen("tcp", "localhost:0")
+		r.NoError(err)
+		cfg.Listener = l
+
+		proxySrv := proxyServer(cfg)
+		defer proxySrv.Close()
+
+		// Track results
+		var successCount, rejectCount int32
+		var wg sync.WaitGroup
+
+		// Try to open 5 concurrent CONNECT tunnels
+		for i := 0; i < 5; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+
+				client, err := proxyClient(proxySrv.URL)
+				if err != nil {
+					t.Logf("Failed to create client: %v", err)
+					return
+				}
+				client.Timeout = 2 * time.Second
+
+				resp, err := client.Get(slowBackend.URL)
+				if err != nil {
+					// Check whether it's a tunnel limit error or a rejection
+					if strings.Contains(err.Error(), "Service Unavailable") ||
+						strings.Contains(err.Error(), "maximum concurrent connect tunnels") ||
+						strings.Contains(err.Error(), "Request rejected by proxy") {
+						atomic.AddInt32(&rejectCount, 1)
+					} else {
+						t.Logf("Request error: %v", err)
+					}
+					return
+				}
+				defer resp.Body.Close()
+
+				if resp.StatusCode == 200 {
+					atomic.AddInt32(&successCount, 1)
+				} else if resp.StatusCode == 503 || resp.StatusCode == 407 {
+					atomic.AddInt32(&rejectCount, 1)
+				}
+			}()
+		}
+
+		wg.Wait()
+
+		// With a limit of 2, we expect at most 2 successes;
+		// some requests should be rejected
+		t.Logf("Success: %d, Rejected: %d", successCount, rejectCount)
+		r.LessOrEqual(int(successCount), 2, "Should not exceed tunnel limit")
+		r.Greater(int(rejectCount), 0, "Some requests should be rejected")
+	})
+
+	t.Run("tunnel slots released on close", func(t *testing.T) {
+		cfg, err := testConfig("test-local-srv")
+		r.NoError(err)
+		err = cfg.SetAllowAddresses([]string{"127.0.0.1"})
+		r.NoError(err)
+
+		// Set the tunnel limit to 1
+		cfg.MaxConcurrentConnectTunnels = 1
+		cfg.TunnelLimiter = NewTunnelLimiter(1, cfg)
+
+		l, err := net.Listen("tcp", "localhost:0")
+		r.NoError(err)
+		cfg.Listener = l
+
+		proxySrv := proxyServer(cfg)
+		defer proxySrv.Close()
+
+		// Quick backend that responds immediately
+		quickBackend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			w.WriteHeader(200)
+			w.Write([]byte("OK"))
+		}))
+		defer quickBackend.Close()
+
+		// Make sequential requests - they should all succeed because
+		// each one completes before the next starts
+		for i := 0; i < 3; i++ {
+			client, err := proxyClient(proxySrv.URL)
+			r.NoError(err)
+			client.Timeout = 5 * time.Second
+
+			resp, err := client.Get(quickBackend.URL)
+			r.NoError(err, "Request %d should succeed", i)
+			resp.Body.Close()
+			r.Equal(200, resp.StatusCode, "Request %d should return 200", i)
+
+			// Wait for the connection to fully close
+			cfg.ConnTracker.Wg().Wait()
+		}
+	})
+
+	t.Run("zero limit means unlimited", func(t *testing.T) {
+		cfg, err := testConfig("test-local-srv")
+		r.NoError(err)
+		err = cfg.SetAllowAddresses([]string{"127.0.0.1"})
+		r.NoError(err)
+
+		// Set the tunnel limit to 0 (unlimited)
+		cfg.MaxConcurrentConnectTunnels = 0
+		cfg.TunnelLimiter = nil // No limiter when unlimited
+
+		l, err := net.Listen("tcp", "localhost:0")
+		r.NoError(err)
+		cfg.Listener = l
+
+		proxySrv := proxyServer(cfg)
+		defer proxySrv.Close()
+
+		// Quick backend
+		quickBackend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			w.WriteHeader(200)
+			w.Write([]byte("OK"))
+		}))
+		defer quickBackend.Close()
+
+		// Should be able to make many requests
+		for i := 0; i < 5; i++ {
+			client, err := proxyClient(proxySrv.URL)
+			r.NoError(err)
+
+			resp, err := client.Get(quickBackend.URL)
+			r.NoError(err, "Request %d should succeed with no limit", i)
+			resp.Body.Close()
+			r.Equal(200, resp.StatusCode)
+		}
+	})
+}
diff --git a/pkg/smokescreen/tunnel_limiter.go b/pkg/smokescreen/tunnel_limiter.go
new file mode 100644
index 00000000..e3e2bdd4
--- /dev/null
+++ b/pkg/smokescreen/tunnel_limiter.go
@@ -0,0 +1,72 @@
+package smokescreen
+
+import (
+	"errors"
+	"sync/atomic"
+)
+
+// ErrTunnelLimitExceeded is returned when the maximum number of concurrent tunnels is reached.
+var ErrTunnelLimitExceeded = errors.New("maximum concurrent connect tunnels exceeded")
+
+// TunnelLimiter limits the number of concurrent CONNECT tunnels.
+// Unlike the rate limiter, which counts requests only while they are being
+// processed, this tracks actual long-lived tunnel connections.
+type TunnelLimiter struct {
+	maxTunnels    int64
+	activeTunnels int64
+	config        *Config
+}
+
+// NewTunnelLimiter creates a new tunnel limiter with the specified maximum.
+// If max is 0 or negative, limiting is disabled.
+func NewTunnelLimiter(max int, config *Config) *TunnelLimiter {
+	return &TunnelLimiter{
+		maxTunnels: int64(max),
+		config:     config,
+	}
+}
+
+// Acquire attempts to acquire a tunnel slot.
+// It returns true on success and false if the limiter is at capacity.
+func (tl *TunnelLimiter) Acquire() bool {
+	if tl == nil || tl.maxTunnels <= 0 {
+		return true // Limiting disabled
+	}
+
+	for {
+		current := atomic.LoadInt64(&tl.activeTunnels)
+		if current >= tl.maxTunnels {
+			if tl.config != nil && tl.config.MetricsClient != nil {
+				tl.config.MetricsClient.Incr("tunnels.concurrency_limited", 1)
+			}
+			return false
+		}
+		if atomic.CompareAndSwapInt64(&tl.activeTunnels, current, current+1) {
+			return true
+		}
+	}
+}
+
+// Release releases a tunnel slot.
+func (tl *TunnelLimiter) Release() {
+	if tl == nil || tl.maxTunnels <= 0 {
+		return
+	}
+	atomic.AddInt64(&tl.activeTunnels, -1)
+}
+
+// ActiveCount returns the current number of active tunnels.
+func (tl *TunnelLimiter) ActiveCount() int64 {
+	if tl == nil {
+		return 0
+	}
+	return atomic.LoadInt64(&tl.activeTunnels)
+}
+
+// MaxTunnels returns the maximum number of allowed tunnels.
+func (tl *TunnelLimiter) MaxTunnels() int64 {
+	if tl == nil {
+		return 0
+	}
+	return tl.maxTunnels
+}
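
One design note on Acquire: the load/compare-and-swap loop is equivalent to a non-blocking buffered-channel semaphore, but the atomic counter keeps ActiveCount cheap and avoids allocating a channel sized to the limit. For comparison, the channel-based alternative would look like this (hypothetical, not in this patch):

    package smokescreen

    // semLimiter sketches the same fail-fast semantics with a buffered channel.
    type semLimiter struct {
    	slots chan struct{}
    }

    func newSemLimiter(max int) *semLimiter {
    	return &semLimiter{slots: make(chan struct{}, max)}
    }

    // acquire returns false immediately at capacity, like TunnelLimiter.Acquire.
    func (l *semLimiter) acquire() bool {
    	select {
    	case l.slots <- struct{}{}:
    		return true
    	default:
    		return false
    	}
    }

    // release frees a slot; callers must pair it with a successful acquire.
    func (l *semLimiter) release() { <-l.slots }
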
diff --git a/pkg/smokescreen/tunnel_limiter_test.go b/pkg/smokescreen/tunnel_limiter_test.go
new file mode 100644
index 00000000..24ac4772
--- /dev/null
+++ b/pkg/smokescreen/tunnel_limiter_test.go
@@ -0,0 +1,101 @@
+package smokescreen
+
+import (
+	"sync"
+	"sync/atomic"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestTunnelLimiter_NilLimiter(t *testing.T) {
+	var tl *TunnelLimiter
+
+	// A nil limiter should always allow
+	assert.True(t, tl.Acquire())
+	tl.Release() // Should not panic
+	assert.Equal(t, int64(0), tl.ActiveCount())
+	assert.Equal(t, int64(0), tl.MaxTunnels())
+}
+
+func TestTunnelLimiter_ZeroLimit(t *testing.T) {
+	tl := NewTunnelLimiter(0, nil)
+
+	// A zero limit means disabled - should always allow
+	assert.True(t, tl.Acquire())
+	assert.True(t, tl.Acquire())
+	assert.True(t, tl.Acquire())
+	tl.Release()
+	tl.Release()
+	tl.Release()
+}
+
+func TestTunnelLimiter_BasicLimit(t *testing.T) {
+	tl := NewTunnelLimiter(3, nil)
+
+	// Should allow up to 3
+	assert.True(t, tl.Acquire())
+	assert.Equal(t, int64(1), tl.ActiveCount())
+
+	assert.True(t, tl.Acquire())
+	assert.Equal(t, int64(2), tl.ActiveCount())
+
+	assert.True(t, tl.Acquire())
+	assert.Equal(t, int64(3), tl.ActiveCount())
+
+	// The fourth should fail
+	assert.False(t, tl.Acquire())
+	assert.Equal(t, int64(3), tl.ActiveCount())
+
+	// Release one
+	tl.Release()
+	assert.Equal(t, int64(2), tl.ActiveCount())
+
+	// Now one more should be allowed
+	assert.True(t, tl.Acquire())
+	assert.Equal(t, int64(3), tl.ActiveCount())
+
+	// Cleanup
+	tl.Release()
+	tl.Release()
+	tl.Release()
+	assert.Equal(t, int64(0), tl.ActiveCount())
+}
+
+func TestTunnelLimiter_ConcurrentAccess(t *testing.T) {
+	const limit = 10
+	const goroutines = 100
+
+	tl := NewTunnelLimiter(limit, nil)
+
+	var acquired int64
+	var wg sync.WaitGroup
+
+	// Try to acquire from many goroutines simultaneously
+	for i := 0; i < goroutines; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if tl.Acquire() {
+				atomic.AddInt64(&acquired, 1)
+			}
+		}()
+	}
+
+	wg.Wait()
+
+	// Should have acquired exactly 'limit' slots
+	assert.Equal(t, int64(limit), acquired)
+	assert.Equal(t, int64(limit), tl.ActiveCount())
+
+	// Release all
+	for i := 0; i < limit; i++ {
+		tl.Release()
+	}
+	assert.Equal(t, int64(0), tl.ActiveCount())
+}
+
+func TestTunnelLimiter_MaxTunnels(t *testing.T) {
+	tl := NewTunnelLimiter(42, nil)
+	assert.Equal(t, int64(42), tl.MaxTunnels())
+}