diff --git a/backend/src/app.module.ts b/backend/src/app.module.ts index 2ed7fd65b..8a8db8fcc 100644 --- a/backend/src/app.module.ts +++ b/backend/src/app.module.ts @@ -1,9 +1,11 @@ import { Module } from '@nestjs/common'; import { ConfigModule, ConfigService } from '@nestjs/config'; -import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler'; +import { ThrottlerModule } from '@nestjs/throttler'; import { APP_GUARD, APP_INTERCEPTOR } from '@nestjs/core'; import { CorrelationIdInterceptor } from './common/interceptors/correlation-id.interceptor'; import { AuditLogInterceptor } from './common/interceptors/audit-log.interceptor'; +import { TieredThrottlerGuard } from './common/guards/tiered-throttler.guard'; +import { CommonModule } from './common/common.module'; import { EventEmitterModule } from '@nestjs/event-emitter'; import { LoggerModule } from 'nestjs-pino'; import * as Joi from 'joi'; @@ -29,7 +31,6 @@ import { NotificationsModule } from './modules/notifications/notifications.modul import { TransactionsModule } from './modules/transactions/transactions.module'; import { TestRbacModule } from './test-rbac/test-rbac.module'; import { TestThrottlingModule } from './test-throttling/test-throttling.module'; -import { CustomThrottlerGuard } from './common/guards/custom-throttler.guard'; const envValidationSchema = Joi.object({ NODE_ENV: Joi.string().valid('development', 'production', 'test').required(), @@ -165,6 +166,7 @@ const envValidationSchema = Joi.object({ TransactionsModule, TestRbacModule, TestThrottlingModule, + CommonModule, ThrottlerModule.forRoot([ { name: 'default', @@ -188,7 +190,7 @@ const envValidationSchema = Joi.object({ AppService, { provide: APP_GUARD, - useClass: ThrottlerGuard, + useClass: TieredThrottlerGuard, }, { provide: APP_INTERCEPTOR, diff --git a/backend/src/auth/auth.service.ts b/backend/src/auth/auth.service.ts index c26b016f9..d76e8542e 100644 --- a/backend/src/auth/auth.service.ts +++ b/backend/src/auth/auth.service.ts @@ -51,7 +51,12 @@ export class AuthService { } return { - accessToken: this.generateToken(user.id, user.email, user.role), + accessToken: this.generateToken( + user.id, + user.email, + user.role, + user.kycStatus, + ), }; } @@ -64,8 +69,13 @@ export class AuthService { return null; } - private generateToken(userId: string, email: string, role = 'USER') { - return this.jwtService.sign({ sub: userId, email, role }); + private generateToken( + userId: string, + email: string, + role = 'USER', + kycStatus = 'NOT_SUBMITTED', + ) { + return this.jwtService.sign({ sub: userId, email, role, kycStatus }); } async generateNonce(publicKey: string): Promise<{ nonce: string }> { diff --git a/backend/src/auth/strategies/jwt.strategy.ts b/backend/src/auth/strategies/jwt.strategy.ts index 5ec26e10c..3698eaec6 100644 --- a/backend/src/auth/strategies/jwt.strategy.ts +++ b/backend/src/auth/strategies/jwt.strategy.ts @@ -18,11 +18,17 @@ export class JwtStrategy extends PassportStrategy(Strategy) { }); } - async validate(payload: { sub: string; email: string; role?: string }) { + async validate(payload: { + sub: string; + email: string; + role?: string; + kycStatus?: string; + }) { return { id: payload.sub, email: payload.email, role: payload.role ?? 'USER', + kycStatus: payload.kycStatus ?? 'NOT_SUBMITTED', }; } } diff --git a/backend/src/common/common.module.ts b/backend/src/common/common.module.ts new file mode 100644 index 000000000..679958000 --- /dev/null +++ b/backend/src/common/common.module.ts @@ -0,0 +1,9 @@ +import { Global, Module } from '@nestjs/common'; +import { RateLimitMonitorService } from './services/rate-limit-monitor.service'; + +@Global() +@Module({ + providers: [RateLimitMonitorService], + exports: [RateLimitMonitorService], +}) +export class CommonModule {} diff --git a/backend/src/common/guards/tiered-throttler.guard.ts b/backend/src/common/guards/tiered-throttler.guard.ts new file mode 100644 index 000000000..d715c792b --- /dev/null +++ b/backend/src/common/guards/tiered-throttler.guard.ts @@ -0,0 +1,193 @@ +import { + Injectable, + ExecutionContext, + Inject, + Logger, + Optional, +} from '@nestjs/common'; +import { ThrottlerGuard, ThrottlerException } from '@nestjs/throttler'; +import { Request, Response } from 'express'; +import { RateLimitMonitorService } from '../services/rate-limit-monitor.service'; + +/** + * User tiers for rate limiting. + * Tier is derived from user role + KYC status. + */ +export enum UserTier { + FREE = 'free', + VERIFIED = 'verified', // KYC approved + PREMIUM = 'premium', // Future: paid plan + ENTERPRISE = 'enterprise', // Future: enterprise plan + ADMIN = 'admin', +} + +/** + * Rate limit configuration per tier. + * Each tier defines limits for each named throttler. + */ +const TIER_LIMITS: Record< + UserTier, + Record +> = { + [UserTier.FREE]: { + default: { limit: 60, ttl: 60000 }, + auth: { limit: 5, ttl: 15 * 60 * 1000 }, + rpc: { limit: 5, ttl: 60000 }, + }, + [UserTier.VERIFIED]: { + default: { limit: 150, ttl: 60000 }, + auth: { limit: 10, ttl: 15 * 60 * 1000 }, + rpc: { limit: 15, ttl: 60000 }, + }, + [UserTier.PREMIUM]: { + default: { limit: 300, ttl: 60000 }, + auth: { limit: 15, ttl: 15 * 60 * 1000 }, + rpc: { limit: 30, ttl: 60000 }, + }, + [UserTier.ENTERPRISE]: { + default: { limit: 1000, ttl: 60000 }, + auth: { limit: 30, ttl: 15 * 60 * 1000 }, + rpc: { limit: 100, ttl: 60000 }, + }, + [UserTier.ADMIN]: { + default: { limit: 1000, ttl: 60000 }, + auth: { limit: 50, ttl: 15 * 60 * 1000 }, + rpc: { limit: 100, ttl: 60000 }, + }, +}; + +/** + * TieredThrottlerGuard - Rate limiting based on user tier. + * + * Determines user tier from JWT payload (role + kycStatus) + * and applies appropriate rate limits. Injects standard + * rate limit headers into every response. + */ +@Injectable() +export class TieredThrottlerGuard extends ThrottlerGuard { + private readonly logger = new Logger(TieredThrottlerGuard.name); + + @Optional() + @Inject(RateLimitMonitorService) + private readonly monitorService?: RateLimitMonitorService; + + /** + * Resolve the user's tier from the request context. + */ + static resolveUserTier(user?: { + role?: string; + kycStatus?: string; + tier?: string; + }): UserTier { + if (!user) return UserTier.FREE; + + // Explicit tier override (for future paid plans) + if (user.tier === 'enterprise') return UserTier.ENTERPRISE; + if (user.tier === 'premium') return UserTier.PREMIUM; + + // Admin always gets highest limits + if (user.role === 'ADMIN') return UserTier.ADMIN; + + // KYC-verified users get higher limits + if (user.kycStatus === 'APPROVED') return UserTier.VERIFIED; + + return UserTier.FREE; + } + + /** + * Get the rate limit config for a user tier and throttler name. + */ + static getLimitsForTier( + tier: UserTier, + throttlerName: string, + ): { limit: number; ttl: number } { + const tierConfig = TIER_LIMITS[tier]; + return tierConfig[throttlerName] || tierConfig['default']; + } + + protected async getTracker(req: Record): Promise { + const user = req.user; + if (user?.id) { + return `tiered-throttle:${user.id}`; + } + const ip = req.ip || req.connection?.remoteAddress || 'unknown'; + return `tiered-throttle:${ip}`; + } + + protected async handleRequest( + requestProps: { + context: ExecutionContext; + limit: number; + ttl: number; + throttler: { name: string; limit: number; ttl: number }; + blockDuration: number; + getTracker: (req: Record) => Promise; + generateKey: ( + context: ExecutionContext, + tracker: string, + throttlerName: string, + ) => string; + }, + ): Promise { + const { context, throttler } = requestProps; + const request = context.switchToHttp().getRequest(); + const response = context.switchToHttp().getResponse(); + const user = (request as any).user; + + const tier = TieredThrottlerGuard.resolveUserTier(user); + const tierLimits = TieredThrottlerGuard.getLimitsForTier( + tier, + throttler.name, + ); + + // Override the limit and ttl with tier-based values + requestProps.limit = tierLimits.limit; + requestProps.ttl = tierLimits.ttl; + + // Set rate limit headers on every response + response.setHeader('X-RateLimit-Limit', tierLimits.limit); + response.setHeader('X-RateLimit-Tier', tier); + + try { + const result = await super.handleRequest(requestProps); + return result; + } catch (error) { + if (error instanceof ThrottlerException) { + this.logger.warn( + `[Rate Limit] Tier: ${tier} | User: ${user?.id || 'anon'} | ` + + `Route: ${request.method} ${request.path} | ` + + `Throttler: ${throttler.name} | ` + + `Limit: ${tierLimits.limit}/${Math.round(tierLimits.ttl / 1000)}s`, + ); + + this.monitorService?.recordViolation({ + userId: user?.id || null, + ip: request.ip || 'unknown', + tier, + route: request.path, + method: request.method, + throttlerName: throttler.name, + limit: tierLimits.limit, + ttl: tierLimits.ttl, + timestamp: new Date(), + }); + + response.setHeader( + 'Retry-After', + Math.ceil(tierLimits.ttl / 1000), + ); + response.setHeader('X-RateLimit-Remaining', 0); + response.setHeader( + 'X-RateLimit-Reset', + new Date(Date.now() + tierLimits.ttl).toISOString(), + ); + + throw new ThrottlerException( + `Rate limit exceeded for ${tier} tier. ` + + `Maximum ${tierLimits.limit} requests per ${Math.round(tierLimits.ttl / 1000)} seconds.`, + ); + } + throw error; + } + } +} diff --git a/backend/src/common/services/rate-limit-monitor.service.ts b/backend/src/common/services/rate-limit-monitor.service.ts new file mode 100644 index 000000000..5e33396b9 --- /dev/null +++ b/backend/src/common/services/rate-limit-monitor.service.ts @@ -0,0 +1,94 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { OnEvent } from '@nestjs/event-emitter'; +import { EventEmitter2 } from '@nestjs/event-emitter'; + +export interface RateLimitViolation { + userId: string | null; + ip: string; + tier: string; + route: string; + method: string; + throttlerName: string; + limit: number; + ttl: number; + timestamp: Date; +} + +/** + * In-memory rate limit violation tracker for admin monitoring. + * Stores the last 1000 violations in a circular buffer. + */ +@Injectable() +export class RateLimitMonitorService { + private readonly logger = new Logger(RateLimitMonitorService.name); + private readonly violations: RateLimitViolation[] = []; + private readonly MAX_VIOLATIONS = 1000; + + constructor(private readonly eventEmitter: EventEmitter2) {} + + recordViolation(violation: RateLimitViolation): void { + if (this.violations.length >= this.MAX_VIOLATIONS) { + this.violations.shift(); + } + this.violations.push(violation); + this.eventEmitter.emit('ratelimit.violation', violation); + } + + getRecentViolations(limit = 50): RateLimitViolation[] { + return this.violations.slice(-limit).reverse(); + } + + getViolationsByUser(userId: string, limit = 50): RateLimitViolation[] { + return this.violations + .filter((v) => v.userId === userId) + .slice(-limit) + .reverse(); + } + + getViolationSummary(): { + total: number; + last24h: number; + topOffenders: { userId: string; count: number }[]; + byTier: Record; + byRoute: Record; + } { + const now = Date.now(); + const dayAgo = now - 24 * 60 * 60 * 1000; + + const last24h = this.violations.filter( + (v) => v.timestamp.getTime() >= dayAgo, + ); + + // Count by user + const userCounts: Record = {}; + for (const v of last24h) { + const key = v.userId || v.ip; + userCounts[key] = (userCounts[key] || 0) + 1; + } + const topOffenders = Object.entries(userCounts) + .map(([userId, count]) => ({ userId, count })) + .sort((a, b) => b.count - a.count) + .slice(0, 10); + + // Count by tier + const byTier: Record = {}; + for (const v of last24h) { + byTier[v.tier] = (byTier[v.tier] || 0) + 1; + } + + // Count by route + const byRoute: Record = {}; + for (const v of last24h) { + const key = `${v.method} ${v.route}`; + byRoute[key] = (byRoute[key] || 0) + 1; + } + + return { + total: this.violations.length, + last24h: last24h.length, + topOffenders, + byTier, + byRoute, + }; + } +} diff --git a/backend/src/modules/admin/admin.controller.ts b/backend/src/modules/admin/admin.controller.ts index 7976952ac..f836f2135 100644 --- a/backend/src/modules/admin/admin.controller.ts +++ b/backend/src/modules/admin/admin.controller.ts @@ -1,23 +1,32 @@ import { Controller, + Get, Patch, Param, Body, + Query, UseGuards, BadRequestException, } from '@nestjs/common'; +import { ApiTags, ApiBearerAuth, ApiOperation, ApiResponse } from '@nestjs/swagger'; import { UserService } from '../user/user.service'; import { JwtAuthGuard } from '../../auth/guards/jwt-auth.guard'; import { RolesGuard } from '../../common/guards/roles.guard'; import { Roles } from '../../common/decorators/roles.decorator'; import { Role } from '../../common/enums/role.enum'; +import { RateLimitMonitorService } from '../../common/services/rate-limit-monitor.service'; import { ApproveKycDto, RejectKycDto } from '../user/dto/update-user.dto'; +@ApiTags('admin') @Controller('admin') @UseGuards(JwtAuthGuard, RolesGuard) @Roles(Role.ADMIN) +@ApiBearerAuth() export class AdminController { - constructor(private readonly userService: UserService) {} + constructor( + private readonly userService: UserService, + private readonly rateLimitMonitor: RateLimitMonitorService, + ) {} @Patch('users/:id/kyc/approve') async approveKyc(@Param('id') userId: string) { @@ -60,4 +69,33 @@ export class AdminController { ); } } + + @Get('rate-limits/summary') + @ApiOperation({ summary: 'Get rate limit violation summary' }) + @ApiResponse({ status: 200, description: 'Rate limit violation summary' }) + getRateLimitSummary() { + return this.rateLimitMonitor.getViolationSummary(); + } + + @Get('rate-limits/violations') + @ApiOperation({ summary: 'Get recent rate limit violations' }) + @ApiResponse({ status: 200, description: 'Recent rate limit violations' }) + getRecentViolations(@Query('limit') limit?: string) { + return this.rateLimitMonitor.getRecentViolations( + limit ? parseInt(limit, 10) : 50, + ); + } + + @Get('rate-limits/violations/:userId') + @ApiOperation({ summary: 'Get rate limit violations for a specific user' }) + @ApiResponse({ status: 200, description: 'User rate limit violations' }) + getUserViolations( + @Param('userId') userId: string, + @Query('limit') limit?: string, + ) { + return this.rateLimitMonitor.getViolationsByUser( + userId, + limit ? parseInt(limit, 10) : 50, + ); + } }