diff --git a/middleware/src/middleware/advanced/circuit-breaker.middleware.ts b/middleware/src/middleware/advanced/circuit-breaker.middleware.ts index 8891dfb..a47454d 100644 --- a/middleware/src/middleware/advanced/circuit-breaker.middleware.ts +++ b/middleware/src/middleware/advanced/circuit-breaker.middleware.ts @@ -8,86 +8,99 @@ export enum CircuitState { } export interface CircuitBreakerOptions { - /** Number of consecutive failures before opening the circuit. Default: 5 */ + /** Number of failures before opening the circuit. Default: 5 */ failureThreshold?: number; + /** Window in ms for counting failures. Default: 60000 (1 minute) */ + timeoutWindow?: number; /** Time in ms to wait before moving from OPEN to HALF_OPEN. Default: 30000 */ - resetTimeout?: number; + halfOpenRetryInterval?: number; /** HTTP status codes considered failures. Default: [500, 502, 503, 504] */ failureStatusCodes?: number[]; } /** * Tracks circuit breaker state and exposes it for health checks. - * - * State machine: - * CLOSED → (N failures) → OPEN - * OPEN → (resetTimeout elapsed) → HALF_OPEN - * HALF_OPEN → (success) → CLOSED | (failure) → OPEN */ @Injectable() export class CircuitBreakerService { private readonly logger = new Logger('CircuitBreakerService'); private state: CircuitState = CircuitState.CLOSED; - private failureCount = 0; + private failureTimestamps: number[] = []; private lastFailureTime: number | null = null; readonly failureThreshold: number; - readonly resetTimeout: number; + readonly timeoutWindow: number; + readonly halfOpenRetryInterval: number; readonly failureStatusCodes: number[]; constructor(options: CircuitBreakerOptions = {}) { this.failureThreshold = options.failureThreshold ?? 5; - this.resetTimeout = options.resetTimeout ?? 30_000; - this.failureStatusCodes = options.failureStatusCodes ?? [500, 502, 503, 504]; + this.timeoutWindow = options.timeoutWindow ?? 60_000; + this.halfOpenRetryInterval = options.halfOpenRetryInterval ?? 30_000; + this.failureStatusCodes = options.failureStatusCodes ?? [ + 500, 502, 503, 504, + ]; } getState(): CircuitState { + const now = Date.now(); + if ( this.state === CircuitState.OPEN && this.lastFailureTime !== null && - Date.now() - this.lastFailureTime >= this.resetTimeout + now - this.lastFailureTime >= this.halfOpenRetryInterval ) { this.logger.log('Circuit transitioning OPEN → HALF_OPEN'); this.state = CircuitState.HALF_OPEN; } + return this.state; } recordSuccess(): void { if (this.state === CircuitState.HALF_OPEN) { this.logger.log('Circuit transitioning HALF_OPEN → CLOSED'); + this.state = CircuitState.CLOSED; + this.failureTimestamps = []; + this.lastFailureTime = null; } - this.state = CircuitState.CLOSED; - this.failureCount = 0; - this.lastFailureTime = null; } recordFailure(): void { - this.failureCount++; - this.lastFailureTime = Date.now(); + const now = Date.now(); + this.lastFailureTime = now; - if ( - this.state === CircuitState.HALF_OPEN || - this.failureCount >= this.failureThreshold - ) { + if (this.state === CircuitState.HALF_OPEN) { + this.logger.warn('Circuit transitioning HALF_OPEN → OPEN'); + this.state = CircuitState.OPEN; + return; + } + + this.failureTimestamps.push(now); + + // Filter failures outside the window + this.failureTimestamps = this.failureTimestamps.filter( + (t) => now - t <= this.timeoutWindow, + ); + + if (this.failureTimestamps.length >= this.failureThreshold) { this.logger.warn( - `Circuit transitioning → OPEN (failures: ${this.failureCount})`, + `Circuit transitioning → OPEN (failures: ${this.failureTimestamps.length})`, ); this.state = CircuitState.OPEN; } } - /** Reset to initial CLOSED state (useful for testing). */ reset(): void { this.state = CircuitState.CLOSED; - this.failureCount = 0; + this.failureTimestamps = []; this.lastFailureTime = null; } } /** - * Middleware that short-circuits requests when the circuit is OPEN, - * returning 503 immediately without hitting downstream handlers. + * Middleware that short-circuits requests when the circuit is OPEN. + * Returns 503 Service Unavailable immediately. */ @Injectable() export class CircuitBreakerMiddleware implements NestMiddleware { @@ -113,7 +126,7 @@ export class CircuitBreakerMiddleware implements NestMiddleware { res.send = (body?: any): Response => { if (this.circuitBreaker.failureStatusCodes.includes(res.statusCode)) { this.circuitBreaker.recordFailure(); - } else { + } else if (res.statusCode >= 200 && res.statusCode < 300) { this.circuitBreaker.recordSuccess(); } return originalSend(body); @@ -122,3 +135,4 @@ export class CircuitBreakerMiddleware implements NestMiddleware { next(); } } + diff --git a/middleware/src/middleware/advanced/timeout.middleware.ts b/middleware/src/middleware/advanced/timeout.middleware.ts index 90b1b00..14eb67a 100644 --- a/middleware/src/middleware/advanced/timeout.middleware.ts +++ b/middleware/src/middleware/advanced/timeout.middleware.ts @@ -1,4 +1,9 @@ -import { Injectable, NestMiddleware, Logger } from '@nestjs/common'; +import { + Injectable, + NestMiddleware, + Logger, + ServiceUnavailableException, +} from '@nestjs/common'; import { Request, Response, NextFunction } from 'express'; export interface TimeoutMiddlewareOptions { @@ -8,7 +13,8 @@ export interface TimeoutMiddlewareOptions { /** * Middleware that enforces a maximum request duration. - * Returns 503 Service Unavailable when the threshold is exceeded. + * Uses Promise.race() to reject after the configured threshold, + * letting NestJS's exception filter handle the 503 response. * * @example * consumer.apply(new TimeoutMiddleware({ timeout: 3000 }).use.bind(timeoutMiddleware)); @@ -22,23 +28,37 @@ export class TimeoutMiddleware implements NestMiddleware { this.timeout = options.timeout ?? 5000; } - use(req: Request, res: Response, next: NextFunction): void { - const timer = setTimeout(() => { - if (!res.headersSent) { + async use(req: Request, res: Response, next: NextFunction): Promise { + let timeoutId: NodeJS.Timeout; + + const timeoutPromise = new Promise((_, reject) => { + timeoutId = setTimeout(() => { this.logger.warn( `Request timed out after ${this.timeout}ms: ${req.method} ${req.path}`, ); - res.status(503).json({ - statusCode: 503, - message: `Request timed out after ${this.timeout}ms`, - error: 'Service Unavailable', - }); - } - }, this.timeout); + reject( + new ServiceUnavailableException( + `Request timed out after ${this.timeout}ms`, + ), + ); + }, this.timeout); + }); - res.on('finish', () => clearTimeout(timer)); - res.on('close', () => clearTimeout(timer)); + const nextPromise = new Promise((resolve) => { + res.on('finish', () => resolve(true)); + res.on('close', () => resolve(true)); + next(); + }); - next(); + try { + await Promise.race([nextPromise, timeoutPromise]); + } catch (error) { + if (!res.headersSent) { + next(error); + } + } finally { + clearTimeout(timeoutId!); + } } } + diff --git a/middleware/tests/unit/circuit-breaker.middleware.spec.ts b/middleware/tests/unit/circuit-breaker.middleware.spec.ts index 09a9dc9..cb27a8e 100644 --- a/middleware/tests/unit/circuit-breaker.middleware.spec.ts +++ b/middleware/tests/unit/circuit-breaker.middleware.spec.ts @@ -25,7 +25,8 @@ describe('CircuitBreakerService', () => { beforeEach(() => { svc = new CircuitBreakerService({ failureThreshold: 3, - resetTimeout: 5000, + timeoutWindow: 10000, + halfOpenRetryInterval: 5000, }); }); @@ -39,14 +40,23 @@ describe('CircuitBreakerService', () => { expect(svc.getState()).toBe(CircuitState.CLOSED); }); - it('transitions CLOSED → OPEN at failure threshold', () => { + it('transitions CLOSED → OPEN at failure threshold within window', () => { svc.recordFailure(); svc.recordFailure(); svc.recordFailure(); expect(svc.getState()).toBe(CircuitState.OPEN); }); - it('transitions OPEN → HALF_OPEN after resetTimeout', () => { + it('does not transition CLOSED → OPEN if failures are outside window', () => { + svc.recordFailure(); + svc.recordFailure(); + jest.advanceTimersByTime(10001); + svc.recordFailure(); + // One failure dropped, count is 1. One more added, count is 2. + expect(svc.getState()).toBe(CircuitState.CLOSED); + }); + + it('transitions OPEN → HALF_OPEN after halfOpenRetryInterval', () => { svc.recordFailure(); svc.recordFailure(); svc.recordFailure(); @@ -78,17 +88,7 @@ describe('CircuitBreakerService', () => { expect(svc.getState()).toBe(CircuitState.OPEN); }); - it('resets failure count on success', () => { - svc.recordFailure(); - svc.recordFailure(); - svc.recordSuccess(); - // Still 2 more failures before threshold of 3 - svc.recordFailure(); - svc.recordFailure(); - expect(svc.getState()).toBe(CircuitState.CLOSED); - }); - - it('reset() restores CLOSED state', () => { + it('resets state correctly with reset()', () => { svc.recordFailure(); svc.recordFailure(); svc.recordFailure(); @@ -145,3 +145,4 @@ describe('CircuitBreakerMiddleware', () => { expect(recordSuccess).toHaveBeenCalledTimes(1); }); }); +