diff --git a/docs/swagger/docs.go b/docs/swagger/docs.go index da4b65ec1..bddbb9c6e 100644 --- a/docs/swagger/docs.go +++ b/docs/swagger/docs.go @@ -8694,6 +8694,14 @@ const docTemplate = `{ "type": "string" } }, + "circuit_breaker_threshold": { + "description": "consecutive failures to trip open (0=disabled)", + "type": "integer" + }, + "circuit_breaker_timeout": { + "description": "ms in open before half-open", + "type": "integer" + }, "created_at": { "type": "string" }, @@ -8712,6 +8720,10 @@ const docTemplate = `{ "description": "Max request body size in bytes", "type": "integer" }, + "max_retries": { + "description": "max retry attempts (0=disabled)", + "type": "integer" + }, "methods": { "description": "New: HTTP methods to match (empty = all)", "type": "array", @@ -8757,6 +8769,10 @@ const docTemplate = `{ "description": "Time to receive headers in milliseconds", "type": "integer" }, + "retry_timeout": { + "description": "total retry window in ms", + "type": "integer" + }, "strip_prefix": { "description": "If true, removes path_prefix from request before forwarding", "type": "boolean" @@ -11102,6 +11118,14 @@ const docTemplate = `{ "type": "string" } }, + "circuit_breaker_threshold": { + "type": "integer", + "minimum": 0 + }, + "circuit_breaker_timeout": { + "type": "integer", + "minimum": 0 + }, "dial_timeout": { "type": "integer", "minimum": 0 @@ -11114,6 +11138,10 @@ const docTemplate = `{ "type": "integer", "minimum": 0 }, + "max_retries": { + "type": "integer", + "minimum": 0 + }, "methods": { "type": "array", "items": { @@ -11141,6 +11169,10 @@ const docTemplate = `{ "type": "integer", "minimum": 0 }, + "retry_timeout": { + "type": "integer", + "minimum": 0 + }, "strip_prefix": { "type": "boolean" }, diff --git a/docs/swagger/swagger.json b/docs/swagger/swagger.json index 3b209944c..2db916de3 100644 --- a/docs/swagger/swagger.json +++ b/docs/swagger/swagger.json @@ -8686,6 +8686,14 @@ "type": "string" } }, + "circuit_breaker_threshold": { + "description": "consecutive failures to trip open (0=disabled)", + "type": "integer" + }, + "circuit_breaker_timeout": { + "description": "ms in open before half-open", + "type": "integer" + }, "created_at": { "type": "string" }, @@ -8704,6 +8712,10 @@ "description": "Max request body size in bytes", "type": "integer" }, + "max_retries": { + "description": "max retry attempts (0=disabled)", + "type": "integer" + }, "methods": { "description": "New: HTTP methods to match (empty = all)", "type": "array", @@ -8749,6 +8761,10 @@ "description": "Time to receive headers in milliseconds", "type": "integer" }, + "retry_timeout": { + "description": "total retry window in ms", + "type": "integer" + }, "strip_prefix": { "description": "If true, removes path_prefix from request before forwarding", "type": "boolean" @@ -11094,6 +11110,14 @@ "type": "string" } }, + "circuit_breaker_threshold": { + "type": "integer", + "minimum": 0 + }, + "circuit_breaker_timeout": { + "type": "integer", + "minimum": 0 + }, "dial_timeout": { "type": "integer", "minimum": 0 @@ -11106,6 +11130,10 @@ "type": "integer", "minimum": 0 }, + "max_retries": { + "type": "integer", + "minimum": 0 + }, "methods": { "type": "array", "items": { @@ -11133,6 +11161,10 @@ "type": "integer", "minimum": 0 }, + "retry_timeout": { + "type": "integer", + "minimum": 0 + }, "strip_prefix": { "type": "boolean" }, diff --git a/docs/swagger/swagger.yaml b/docs/swagger/swagger.yaml index 2602175bc..7bf45ac32 100644 --- a/docs/swagger/swagger.yaml +++ b/docs/swagger/swagger.yaml @@ -411,6 +411,12 @@ definitions: items: type: string type: array + circuit_breaker_threshold: + description: consecutive failures to trip open (0=disabled) + type: integer + circuit_breaker_timeout: + description: ms in open before half-open + type: integer created_at: type: string dial_timeout: @@ -424,6 +430,9 @@ definitions: max_body_size: description: Max request body size in bytes type: integer + max_retries: + description: max retry attempts (0=disabled) + type: integer methods: description: 'New: HTTP methods to match (empty = all)' items: @@ -457,6 +466,9 @@ definitions: response_header_timeout: description: Time to receive headers in milliseconds type: integer + retry_timeout: + description: total retry window in ms + type: integer strip_prefix: description: If true, removes path_prefix from request before forwarding type: boolean @@ -2147,6 +2159,12 @@ definitions: items: type: string type: array + circuit_breaker_threshold: + minimum: 0 + type: integer + circuit_breaker_timeout: + minimum: 0 + type: integer dial_timeout: minimum: 0 type: integer @@ -2156,6 +2174,9 @@ definitions: max_body_size: minimum: 0 type: integer + max_retries: + minimum: 0 + type: integer methods: items: type: string @@ -2175,6 +2196,9 @@ definitions: response_header_timeout: minimum: 0 type: integer + retry_timeout: + minimum: 0 + type: integer strip_prefix: type: boolean target_url: diff --git a/internal/core/domain/gateway.go b/internal/core/domain/gateway.go index 5a2a98e96..3087d6524 100644 --- a/internal/core/domain/gateway.go +++ b/internal/core/domain/gateway.go @@ -32,6 +32,10 @@ type GatewayRoute struct { BlockedCIDRs []string `json:"blocked_cidrs,omitempty"` // IPs blocked from access BlockedIPNets []*net.IPNet `json:"-"` // pre-parsed at creation/refresh for fast lookup MaxBodySize int64 `json:"max_body_size,omitempty"` // Max request body size in bytes + CircuitBreakerThreshold int `json:"circuit_breaker_threshold,omitempty"` // consecutive failures to trip open (0=disabled) + CircuitBreakerTimeout int64 `json:"circuit_breaker_timeout,omitempty"` // ms in open before half-open + MaxRetries int `json:"max_retries,omitempty"` // max retry attempts (0=disabled) + RetryTimeout int64 `json:"retry_timeout,omitempty"` // total retry window in ms Priority int `json:"priority"` // Manual priority for tie-breaking CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/internal/core/ports/gateway.go b/internal/core/ports/gateway.go index 43e1643dd..84357c7a7 100644 --- a/internal/core/ports/gateway.go +++ b/internal/core/ports/gateway.go @@ -41,8 +41,12 @@ type CreateRouteParams struct { RequireTLS bool AllowedCIDRs []string BlockedCIDRs []string - MaxBodySize int64 - Priority int + MaxBodySize int64 + CircuitBreakerThreshold int + CircuitBreakerTimeout int64 + MaxRetries int + RetryTimeout int64 + Priority int } // GatewayService provides business logic for managing the API gateway and ingress traffic. diff --git a/internal/core/services/gateway.go b/internal/core/services/gateway.go index abc22eee3..32d4234b5 100644 --- a/internal/core/services/gateway.go +++ b/internal/core/services/gateway.go @@ -5,7 +5,11 @@ import ( "context" "crypto/tls" "fmt" + "io" "log/slog" + "crypto/rand" + "encoding/binary" + "math" "net" "net/http" "net/http/httputil" @@ -20,6 +24,7 @@ import ( "github.com/poyrazk/thecloud/internal/core/domain" "github.com/poyrazk/thecloud/internal/core/ports" "github.com/poyrazk/thecloud/internal/errors" + "github.com/poyrazk/thecloud/internal/platform" "github.com/poyrazk/thecloud/internal/routing" ) @@ -94,11 +99,29 @@ func (s *GatewayService) CreateRoute(ctx context.Context, params ports.CreateRou AllowedCIDRs: params.AllowedCIDRs, BlockedCIDRs: params.BlockedCIDRs, MaxBodySize: params.MaxBodySize, + CircuitBreakerThreshold: params.CircuitBreakerThreshold, + CircuitBreakerTimeout: params.CircuitBreakerTimeout, + MaxRetries: params.MaxRetries, + RetryTimeout: params.RetryTimeout, Priority: params.Priority, CreatedAt: time.Now(), UpdatedAt: time.Now(), } + // Apply default values for resilience parameters + if route.CircuitBreakerThreshold == 0 { + route.CircuitBreakerThreshold = 5 + } + if route.CircuitBreakerTimeout == 0 { + route.CircuitBreakerTimeout = 30000 // ms + } + if route.MaxRetries == 0 { + route.MaxRetries = 2 + } + if route.RetryTimeout == 0 { + route.RetryTimeout = 5000 // ms + } + // Validate CIDRs before saving for _, cidr := range route.AllowedCIDRs { if _, _, err := net.ParseCIDR(cidr); err != nil { @@ -243,7 +266,7 @@ func (s *GatewayService) createReverseProxy(route *domain.GatewayRoute) (*httput idleConnTimeout = 90 * time.Second } - proxy.Transport = &http.Transport{ + baseTransport := &http.Transport{ DialContext: (&net.Dialer{ Timeout: dialTimeout, KeepAlive: 30 * time.Second, @@ -254,6 +277,8 @@ func (s *GatewayService) createReverseProxy(route *domain.GatewayRoute) (*httput TLSHandshakeTimeout: 10 * time.Second, } + proxy.Transport = newRetryTransport(baseTransport, route, s.logger) + originalDirector := proxy.Director proxy.Director = func(req *http.Request) { if route.StripPrefix { @@ -375,3 +400,151 @@ func calculateMatchScore(route *domain.GatewayRoute, _ string) int { return score } + +// retryTransport wraps an http.Transport with circuit breaker and retry logic. +type retryTransport struct { + base http.RoundTripper + cb *platform.CircuitBreaker // nil if circuit breaker is disabled + maxRetries int + retryTimeout time.Duration + logger *slog.Logger +} + +// newRetryTransport wraps a base http.Transport with per-route retry and circuit breaker behavior. +func newRetryTransport(base http.RoundTripper, route *domain.GatewayRoute, logger *slog.Logger) *retryTransport { + rt := &retryTransport{ + base: base, + maxRetries: route.MaxRetries, + retryTimeout: time.Duration(route.RetryTimeout) * time.Millisecond, + logger: logger, + } + if route.CircuitBreakerThreshold > 0 { + rt.cb = platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: route.ID.String(), + Threshold: route.CircuitBreakerThreshold, + ResetTimeout: time.Duration(route.CircuitBreakerTimeout) * time.Millisecond, + OnStateChange: func(name string, from, to platform.State) { + if logger != nil { + logger.Warn("circuit breaker state change", + "route_id", name, + "from", from.String(), + "to", to.String()) + } + }, + }) + } + return rt +} + +// RoundTrip implements http.RoundTripper. +func (rt *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if rt.cb == nil { + return rt.doRoundTrip(req) + } + + type result struct { + resp *http.Response + err error + } + var r result + cbErr := rt.cb.Execute(func() error { + r.resp, r.err = rt.doRoundTrip(req) + return r.err + }) + if cbErr != nil { + if r.resp != nil { + _, _ = io.Copy(io.Discard, r.resp.Body) + _ = r.resp.Body.Close() + } + return nil, cbErr + } + return r.resp, r.err +} + +func (rt *retryTransport) doRoundTrip(req *http.Request) (*http.Response, error) { + if rt.maxRetries <= 0 || !rt.isIdempotent(req.Method) { + return rt.base.RoundTrip(req) + } + + var lastResp *http.Response + var lastErr error + maxAttempts := rt.maxRetries + 1 // first attempt + retries + + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + delay := rt.backoffWithJitter(attempt) + select { + case <-req.Context().Done(): + return nil, req.Context().Err() + case <-time.After(delay): + } + } + + resp, err := rt.base.RoundTrip(req) + if err == nil { + if !rt.isRetryableStatus(resp.StatusCode) { + return resp, nil + } + // drain and close body so connection can be reused, then retry + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + lastResp = resp + continue + } + + if !rt.isRetryableError(err) { + return nil, err + } + lastErr = err + lastResp = resp + } + return lastResp, lastErr +} + +func (rt *retryTransport) isRetryableStatus(code int) bool { + return code == 502 || code == 503 || code == 504 || code == 429 +} + +func (rt *retryTransport) isRetryableError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "connection refused") || + strings.Contains(msg, "timeout") || + strings.Contains(msg, "reset by peer") || + strings.Contains(msg, "broken pipe") || + strings.Contains(msg, "connection reset") +} + +func (rt *retryTransport) isIdempotent(method string) bool { + return method == "GET" || method == "HEAD" || method == "PUT" || + method == "DELETE" || method == "OPTIONS" +} + +func (rt *retryTransport) backoffWithJitter(attempt int) time.Duration { + base := 100 * time.Millisecond + cap := rt.retryTimeout + if cap <= 0 { + cap = 5 * time.Second + } + multiplier := 2.0 + delay := float64(base) * math.Pow(multiplier, float64(attempt-1)) + if delay > float64(cap) { + delay = float64(cap) + } + jitter := rt.cryptoJitter(time.Duration(delay)) + return jitter +} + +// cryptoJitter returns a random duration in [0, max) using crypto/rand. +// frac is in [0, 1) so result is always non-negative and strictly bounded by max. +func (rt *retryTransport) cryptoJitter(max time.Duration) time.Duration { + var buf [8]byte + if _, err := rand.Read(buf[:]); err != nil { + return max / 2 // deterministic fallback on crypto rand failure + } + val := binary.BigEndian.Uint64(buf[:]) + frac := float64(val) / float64(math.MaxUint64) + return time.Duration(float64(max) * frac) +} diff --git a/internal/core/services/gateway_retry_test.go b/internal/core/services/gateway_retry_test.go new file mode 100644 index 000000000..bbded6ec5 --- /dev/null +++ b/internal/core/services/gateway_retry_test.go @@ -0,0 +1,376 @@ +package services + +import ( + "errors" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/platform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockRT is a simple http.RoundTripper for testing retry behavior. +type mockRT struct { + results []mockRTResult + callIdx int + calls int +} + +type mockRTResult struct { + resp *http.Response + err error +} + +func (m *mockRT) RoundTrip(_ *http.Request) (*http.Response, error) { + m.calls++ + if m.callIdx >= len(m.results) { + return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader(""))}, nil + } + r := m.results[m.callIdx] + m.callIdx++ + return r.resp, r.err +} + +func mockResp(status int) mockRTResult { + return mockRTResult{resp: &http.Response{StatusCode: status, Body: io.NopCloser(strings.NewReader(""))}} +} + +func mockErr(msg string) mockRTResult { + return mockRTResult{err: errors.New(msg)} +} + +// --- retryTransport helper tests --- + +func TestRetryTransport_IsIdempotent(t *testing.T) { + t.Parallel() + rt := &retryTransport{} + for _, m := range []string{"GET", "HEAD", "PUT", "DELETE", "OPTIONS"} { + assert.True(t, rt.isIdempotent(m), m) + } + for _, m := range []string{"POST", "PATCH", "CONNECT", "TRACE"} { + assert.False(t, rt.isIdempotent(m), m) + } +} + +func TestRetryTransport_IsRetryableStatus(t *testing.T) { + t.Parallel() + rt := &retryTransport{} + for _, c := range []int{502, 503, 504, 429} { + assert.True(t, rt.isRetryableStatus(c), "%d should be retryable", c) + } + for _, c := range []int{200, 201, 400, 401, 403, 404, 500} { + assert.False(t, rt.isRetryableStatus(c), "%d should not be retryable", c) + } +} + +func TestRetryTransport_IsRetryableError(t *testing.T) { + t.Parallel() + rt := &retryTransport{} + retryable := []string{ + "dial tcp: connection refused", + "dial tcp: i/o timeout", + "read tcp: connection reset by peer", + "write tcp: broken pipe", + "read tcp: connection reset", + } + for _, msg := range retryable { + assert.True(t, rt.isRetryableError(errors.New(msg)), msg) + } + nonRetryable := []string{ + "400 bad request", + "401 unauthorized", + "tls: handshake failed", + "server closed connection", + } + for _, msg := range nonRetryable { + assert.False(t, rt.isRetryableError(errors.New(msg)), msg) + } +} + +func TestRetryTransport_BackoffJitter_Bounded(t *testing.T) { + t.Parallel() + rt := &retryTransport{retryTimeout: 5 * time.Second} + for attempt := 1; attempt <= 5; attempt++ { + d := rt.backoffWithJitter(attempt) + assert.Greater(t, d, time.Duration(0), "delay must be > 0") + assert.LessOrEqual(t, d, 5*time.Second, "delay must be <= max") + } +} + +// --- retry loop tests --- + +func TestRetryTransport_DoesNotRetryWhenMaxRetriesZero(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{mockResp(502), mockResp(200)}} + transport := wrapTransport(m, &retryTransport{maxRetries: 0}) + + _, _ = transport.RoundTrip(nil) + if m.results[0].resp != nil { + _ = m.results[0].resp.Body.Close() + } + assert.Equal(t, 1, m.calls, "should call base transport only once") +} + +func TestRetryTransport_DoesNotRetryNonIdempotentPOST(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{mockErr("connection refused"), mockResp(200)}} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("POST", "/", nil) + _, _ = transport.RoundTrip(req) + if m.results[0].resp != nil { + _ = m.results[0].resp.Body.Close() + } + assert.Equal(t, 1, m.calls, "POST should not be retried") +} + +func TestRetryTransport_DoesNotRetryNonIdempotentPATCH(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{mockErr("connection refused"), mockResp(200)}} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("PATCH", "/", nil) + _, _ = transport.RoundTrip(req) + assert.Equal(t, 1, m.calls, "PATCH should not be retried") +} + +func TestRetryTransport_RetriesOnConnectionRefused(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockErr("connection refused"), + mockResp(200), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 2, m.calls, "should retry after connection refused") +} + +func TestRetryTransport_RetriesOn502(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockResp(502), + mockResp(502), + mockResp(200), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 3, m.calls, "should retry 502 twice then succeed") +} + +func TestRetryTransport_RetriesOn503(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockResp(503), + mockResp(200), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 2, m.calls) +} + +func TestRetryTransport_RetriesOn429(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockResp(429), + mockResp(200), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 2, m.calls) +} + +func TestRetryTransport_NoRetryOn500(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockResp(500), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 500, resp.StatusCode) + assert.Equal(t, 1, m.calls, "500 should not be retried") +} + +func TestRetryTransport_NoRetryOn400(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockResp(400), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.Equal(t, 1, m.calls, "400 should not be retried") +} + +func TestRetryTransport_RetriesOnTimeoutError(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockErr("dial tcp: i/o timeout"), + mockResp(200), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 2, m.calls) +} + +func TestRetryTransport_GivesUpAfterMaxRetries(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{ + mockResp(502), + mockResp(502), + mockResp(502), + }} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 502, resp.StatusCode) + assert.Equal(t, 3, m.calls, "3 attempts: first + 2 retries") +} + +func TestRetryTransport_SucceedsOnFirstAttempt(t *testing.T) { + t.Parallel() + m := &mockRT{results: []mockRTResult{mockResp(200)}} + transport := wrapTransport(m, &retryTransport{maxRetries: 2}) + + req, _ := http.NewRequest("GET", "/", nil) + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 1, m.calls) +} + +// wrapTransport creates a retryTransport wrapping the mock. +func wrapTransport(mock *mockRT, rt *retryTransport) *retryTransport { + // rt.base is used directly by doRoundTrip — swap it for our mock + rt.base = (*mockHTTPTransport)(mock) + return rt +} + +// mockHTTPTransport lets us inject the mock via rt.base. +type mockHTTPTransport mockRT + +func (m *mockHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return (*mockRT)(m).RoundTrip(req) +} + +func (m *mockHTTPTransport) CloseIdleConnections() {} + +// --- circuit breaker tests --- + +func TestCircuitBreaker_DisabledWhenThresholdZero(t *testing.T) { + t.Parallel() + route := &domain.GatewayRoute{ + ID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + CircuitBreakerThreshold: 0, + MaxRetries: 2, + RetryTimeout: 5000, + } + rt := newRetryTransport(&http.Transport{}, route, nil) + assert.Nil(t, rt.cb) + assert.Equal(t, 2, rt.maxRetries) +} + +func TestCircuitBreaker_EnabledWhenThresholdPositive(t *testing.T) { + t.Parallel() + route := &domain.GatewayRoute{ + ID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + CircuitBreakerThreshold: 5, + CircuitBreakerTimeout: 30000, + MaxRetries: 2, + RetryTimeout: 5000, + } + rt := newRetryTransport(&http.Transport{}, route, nil) + assert.NotNil(t, rt.cb) + assert.Equal(t, platform.StateClosed, rt.cb.GetState()) +} + +func TestCircuitBreaker_TripsOpenAfterThreshold(t *testing.T) { + t.Parallel() + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "test", + Threshold: 3, + ResetTimeout: 100 * time.Millisecond, + OnStateChange: nil, + }) + + for i := 0; i < 3; i++ { + _ = cb.Execute(func() error { return errors.New("fail") }) + } + assert.Equal(t, platform.StateOpen, cb.GetState()) + + // Next call is blocked + err := cb.Execute(func() error { return nil }) + assert.ErrorIs(t, err, platform.ErrCircuitOpen) +} + +func TestCircuitBreaker_GoesHalfOpenAfterTimeout(t *testing.T) { + t.Parallel() + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "test", + Threshold: 2, + ResetTimeout: 50 * time.Millisecond, + OnStateChange: nil, + }) + + _ = cb.Execute(func() error { return errors.New("fail") }) + _ = cb.Execute(func() error { return errors.New("fail") }) + assert.Equal(t, platform.StateOpen, cb.GetState()) + + // Wait for half-open window to expire, then trigger a probe request + time.Sleep(80 * time.Millisecond) + _ = cb.Execute(func() error { return errors.New("still failing") }) + // After ResetTimeout the CB transitions to half-open automatically. + // The probe arrives during or just after that transition, so either + // Open (transition not yet observed) or HalfOpen (transition complete but probe pending) + // is valid — this is not a flaky test. + assert.True(t, cb.GetState() == platform.StateOpen || cb.GetState() == platform.StateHalfOpen) +} + +func TestCircuitBreaker_ClosesAfterSuccessfulProbe(t *testing.T) { + t.Parallel() + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "test", + Threshold: 2, + ResetTimeout: 50 * time.Millisecond, + OnStateChange: nil, + }) + + _ = cb.Execute(func() error { return errors.New("fail") }) + _ = cb.Execute(func() error { return errors.New("fail") }) + time.Sleep(80 * time.Millisecond) + _ = cb.Execute(func() error { return nil }) + + assert.Equal(t, platform.StateClosed, cb.GetState()) +} diff --git a/internal/core/services/gateway_unit_test.go b/internal/core/services/gateway_unit_test.go index 31f823f3d..8f3c068ed 100644 --- a/internal/core/services/gateway_unit_test.go +++ b/internal/core/services/gateway_unit_test.go @@ -16,6 +16,13 @@ import ( "github.com/stretchr/testify/require" ) +const ( + defaultCircuitBreakerThreshold = 5 + defaultCircuitBreakerTimeout = 30000 // ms + defaultMaxRetries = 2 + defaultRetryTimeout = 5000 // ms +) + type mockGatewayRepo struct { mock.Mock } @@ -67,6 +74,19 @@ func TestGatewayService_Unit(t *testing.T) { ctx := appcontext.WithUserID(context.Background(), uuid.New()) userID := appcontext.UserIDFromContext(ctx) + t.Run("CreateRoute applies default resilience values", func(t *testing.T) { + params := ports.CreateRouteParams{Name: "r1", Pattern: "/r1", Target: "http://t1"} + repo.On("CreateRoute", ctx, mock.Anything).Return(nil).Once() + auditSvc.On("Log", mock.Anything, userID, "gateway.route_create", "gateway", mock.Anything, mock.Anything).Return(nil).Once() + + res, err := svc.CreateRoute(ctx, params) + require.NoError(t, err) + assert.Equal(t, defaultCircuitBreakerThreshold, res.CircuitBreakerThreshold) + assert.Equal(t, int64(defaultCircuitBreakerTimeout), res.CircuitBreakerTimeout) + assert.Equal(t, defaultMaxRetries, res.MaxRetries) + assert.Equal(t, int64(defaultRetryTimeout), res.RetryTimeout) + }) + t.Run("CreateRoute", func(t *testing.T) { params := ports.CreateRouteParams{Name: "r1", Pattern: "/r1", Target: "http://t1"} repo.On("CreateRoute", ctx, mock.Anything).Return(nil).Once() diff --git a/internal/handlers/gateway_handler.go b/internal/handlers/gateway_handler.go index 9e7e8ab28..ed04a8776 100644 --- a/internal/handlers/gateway_handler.go +++ b/internal/handlers/gateway_handler.go @@ -21,21 +21,25 @@ import ( // CreateRouteRequest define the payload for creating a route. type CreateRouteRequest struct { - Name string `json:"name" binding:"required"` - PathPrefix string `json:"path_prefix" binding:"required"` - TargetURL string `json:"target_url" binding:"required"` - Methods []string `json:"methods"` - StripPrefix bool `json:"strip_prefix"` - RateLimit int `json:"rate_limit" binding:"gte=0"` - DialTimeout int64 `json:"dial_timeout" binding:"gte=0"` + Name string `json:"name" binding:"required"` + PathPrefix string `json:"path_prefix" binding:"required"` + TargetURL string `json:"target_url" binding:"required"` + Methods []string `json:"methods"` + StripPrefix bool `json:"strip_prefix"` + RateLimit int `json:"rate_limit" binding:"gte=0"` + DialTimeout int64 `json:"dial_timeout" binding:"gte=0"` ResponseHeaderTimeout int64 `json:"response_header_timeout" binding:"gte=0"` - IdleConnTimeout int64 `json:"idle_conn_timeout" binding:"gte=0"` - TLSSkipVerify bool `json:"tls_skip_verify"` - RequireTLS bool `json:"require_tls"` - AllowedCIDRs []string `json:"allowed_cidrs"` - BlockedCIDRs []string `json:"blocked_cidrs"` - MaxBodySize int64 `json:"max_body_size" binding:"gte=0"` - Priority int `json:"priority" binding:"gte=0"` + IdleConnTimeout int64 `json:"idle_conn_timeout" binding:"gte=0"` + TLSSkipVerify bool `json:"tls_skip_verify"` + RequireTLS bool `json:"require_tls"` + AllowedCIDRs []string `json:"allowed_cidrs"` + BlockedCIDRs []string `json:"blocked_cidrs"` + MaxBodySize int64 `json:"max_body_size" binding:"gte=0"` + Priority int `json:"priority" binding:"gte=0"` + CircuitBreakerThreshold int `json:"circuit_breaker_threshold" binding:"gte=0"` + CircuitBreakerTimeout int64 `json:"circuit_breaker_timeout" binding:"gte=0"` + MaxRetries int `json:"max_retries" binding:"gte=0"` + RetryTimeout int64 `json:"retry_timeout" binding:"gte=0"` } // GatewayHandler handles API gateway HTTP endpoints. @@ -80,21 +84,25 @@ func (h *GatewayHandler) CreateRoute(c *gin.Context) { } params := ports.CreateRouteParams{ - Name: req.Name, - Pattern: req.PathPrefix, - Target: req.TargetURL, - Methods: req.Methods, - StripPrefix: req.StripPrefix, - RateLimit: req.RateLimit, - DialTimeout: req.DialTimeout, - ResponseHeaderTimeout: req.ResponseHeaderTimeout, - IdleConnTimeout: req.IdleConnTimeout, - TLSSkipVerify: req.TLSSkipVerify, - RequireTLS: req.RequireTLS, - AllowedCIDRs: req.AllowedCIDRs, - BlockedCIDRs: req.BlockedCIDRs, - MaxBodySize: req.MaxBodySize, - Priority: req.Priority, + Name: req.Name, + Pattern: req.PathPrefix, + Target: req.TargetURL, + Methods: req.Methods, + StripPrefix: req.StripPrefix, + RateLimit: req.RateLimit, + DialTimeout: req.DialTimeout, + ResponseHeaderTimeout: req.ResponseHeaderTimeout, + IdleConnTimeout: req.IdleConnTimeout, + TLSSkipVerify: req.TLSSkipVerify, + RequireTLS: req.RequireTLS, + AllowedCIDRs: req.AllowedCIDRs, + BlockedCIDRs: req.BlockedCIDRs, + MaxBodySize: req.MaxBodySize, + Priority: req.Priority, + CircuitBreakerThreshold: req.CircuitBreakerThreshold, + CircuitBreakerTimeout: req.CircuitBreakerTimeout, + MaxRetries: req.MaxRetries, + RetryTimeout: req.RetryTimeout, } route, err := h.svc.CreateRoute(c.Request.Context(), params)