Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backend/src/app.module.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler';
import { ThrottlerModule } from '@nestjs/throttler';
import { APP_GUARD, APP_INTERCEPTOR } from '@nestjs/core';
import { CorrelationIdInterceptor } from './common/interceptors/correlation-id.interceptor';
import { AuditLogInterceptor } from './common/interceptors/audit-log.interceptor';
import { TieredThrottlerGuard } from './common/guards/tiered-throttler.guard';
import { CommonModule } from './common/common.module';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { LoggerModule } from 'nestjs-pino';
import * as Joi from 'joi';
Expand All @@ -29,7 +31,6 @@ import { NotificationsModule } from './modules/notifications/notifications.modul
import { TransactionsModule } from './modules/transactions/transactions.module';
import { TestRbacModule } from './test-rbac/test-rbac.module';
import { TestThrottlingModule } from './test-throttling/test-throttling.module';
import { CustomThrottlerGuard } from './common/guards/custom-throttler.guard';

const envValidationSchema = Joi.object({
NODE_ENV: Joi.string().valid('development', 'production', 'test').required(),
Expand Down Expand Up @@ -165,6 +166,7 @@ const envValidationSchema = Joi.object({
TransactionsModule,
TestRbacModule,
TestThrottlingModule,
CommonModule,
ThrottlerModule.forRoot([
{
name: 'default',
Expand All @@ -188,7 +190,7 @@ const envValidationSchema = Joi.object({
AppService,
{
provide: APP_GUARD,
useClass: ThrottlerGuard,
useClass: TieredThrottlerGuard,
},
{
provide: APP_INTERCEPTOR,
Expand Down
16 changes: 13 additions & 3 deletions backend/src/auth/auth.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ export class AuthService {
}

return {
accessToken: this.generateToken(user.id, user.email, user.role),
accessToken: this.generateToken(
user.id,
user.email,
user.role,
user.kycStatus,
),
};
}

Expand All @@ -64,8 +69,13 @@ export class AuthService {
return null;
}

private generateToken(userId: string, email: string, role = 'USER') {
return this.jwtService.sign({ sub: userId, email, role });
private generateToken(
userId: string,
email: string,
role = 'USER',
kycStatus = 'NOT_SUBMITTED',
) {
return this.jwtService.sign({ sub: userId, email, role, kycStatus });
}

async generateNonce(publicKey: string): Promise<{ nonce: string }> {
Expand Down
8 changes: 7 additions & 1 deletion backend/src/auth/strategies/jwt.strategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@ export class JwtStrategy extends PassportStrategy(Strategy) {
});
}

async validate(payload: { sub: string; email: string; role?: string }) {
async validate(payload: {
sub: string;
email: string;
role?: string;
kycStatus?: string;
}) {
return {
id: payload.sub,
email: payload.email,
role: payload.role ?? 'USER',
kycStatus: payload.kycStatus ?? 'NOT_SUBMITTED',
};
}
}
9 changes: 9 additions & 0 deletions backend/src/common/common.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { Global, Module } from '@nestjs/common';
import { RateLimitMonitorService } from './services/rate-limit-monitor.service';

@Global()
@Module({
providers: [RateLimitMonitorService],
exports: [RateLimitMonitorService],
})
export class CommonModule {}
193 changes: 193 additions & 0 deletions backend/src/common/guards/tiered-throttler.guard.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import {
Injectable,
ExecutionContext,
Inject,
Logger,
Optional,
} from '@nestjs/common';
import { ThrottlerGuard, ThrottlerException } from '@nestjs/throttler';
import { Request, Response } from 'express';
import { RateLimitMonitorService } from '../services/rate-limit-monitor.service';

/**
* User tiers for rate limiting.
* Tier is derived from user role + KYC status.
*/
export enum UserTier {
FREE = 'free',
VERIFIED = 'verified', // KYC approved
PREMIUM = 'premium', // Future: paid plan
ENTERPRISE = 'enterprise', // Future: enterprise plan
ADMIN = 'admin',
}

/**
* Rate limit configuration per tier.
* Each tier defines limits for each named throttler.
*/
const TIER_LIMITS: Record<
UserTier,
Record<string, { limit: number; ttl: number }>
> = {
[UserTier.FREE]: {
default: { limit: 60, ttl: 60000 },
auth: { limit: 5, ttl: 15 * 60 * 1000 },
rpc: { limit: 5, ttl: 60000 },
},
[UserTier.VERIFIED]: {
default: { limit: 150, ttl: 60000 },
auth: { limit: 10, ttl: 15 * 60 * 1000 },
rpc: { limit: 15, ttl: 60000 },
},
[UserTier.PREMIUM]: {
default: { limit: 300, ttl: 60000 },
auth: { limit: 15, ttl: 15 * 60 * 1000 },
rpc: { limit: 30, ttl: 60000 },
},
[UserTier.ENTERPRISE]: {
default: { limit: 1000, ttl: 60000 },
auth: { limit: 30, ttl: 15 * 60 * 1000 },
rpc: { limit: 100, ttl: 60000 },
},
[UserTier.ADMIN]: {
default: { limit: 1000, ttl: 60000 },
auth: { limit: 50, ttl: 15 * 60 * 1000 },
rpc: { limit: 100, ttl: 60000 },
},
};

/**
* TieredThrottlerGuard - Rate limiting based on user tier.
*
* Determines user tier from JWT payload (role + kycStatus)
* and applies appropriate rate limits. Injects standard
* rate limit headers into every response.
*/
@Injectable()
export class TieredThrottlerGuard extends ThrottlerGuard {
private readonly logger = new Logger(TieredThrottlerGuard.name);

@Optional()
@Inject(RateLimitMonitorService)
private readonly monitorService?: RateLimitMonitorService;

/**
* Resolve the user's tier from the request context.
*/
static resolveUserTier(user?: {
role?: string;
kycStatus?: string;
tier?: string;
}): UserTier {
if (!user) return UserTier.FREE;

// Explicit tier override (for future paid plans)
if (user.tier === 'enterprise') return UserTier.ENTERPRISE;
if (user.tier === 'premium') return UserTier.PREMIUM;

// Admin always gets highest limits
if (user.role === 'ADMIN') return UserTier.ADMIN;

// KYC-verified users get higher limits
if (user.kycStatus === 'APPROVED') return UserTier.VERIFIED;

return UserTier.FREE;
}

/**
* Get the rate limit config for a user tier and throttler name.
*/
static getLimitsForTier(
tier: UserTier,
throttlerName: string,
): { limit: number; ttl: number } {
const tierConfig = TIER_LIMITS[tier];
return tierConfig[throttlerName] || tierConfig['default'];
}

protected async getTracker(req: Record<string, any>): Promise<string> {
const user = req.user;
if (user?.id) {
return `tiered-throttle:${user.id}`;
}
const ip = req.ip || req.connection?.remoteAddress || 'unknown';
return `tiered-throttle:${ip}`;
}

protected async handleRequest(
requestProps: {
context: ExecutionContext;
limit: number;
ttl: number;
throttler: { name: string; limit: number; ttl: number };
blockDuration: number;
getTracker: (req: Record<string, any>) => Promise<string>;
generateKey: (
context: ExecutionContext,
tracker: string,
throttlerName: string,
) => string;
},
): Promise<boolean> {
const { context, throttler } = requestProps;
const request = context.switchToHttp().getRequest<Request>();
const response = context.switchToHttp().getResponse<Response>();
const user = (request as any).user;

const tier = TieredThrottlerGuard.resolveUserTier(user);
const tierLimits = TieredThrottlerGuard.getLimitsForTier(
tier,
throttler.name,
);

// Override the limit and ttl with tier-based values
requestProps.limit = tierLimits.limit;
requestProps.ttl = tierLimits.ttl;

// Set rate limit headers on every response
response.setHeader('X-RateLimit-Limit', tierLimits.limit);
response.setHeader('X-RateLimit-Tier', tier);

try {
const result = await super.handleRequest(requestProps);
return result;
} catch (error) {
if (error instanceof ThrottlerException) {
this.logger.warn(
`[Rate Limit] Tier: ${tier} | User: ${user?.id || 'anon'} | ` +
`Route: ${request.method} ${request.path} | ` +
`Throttler: ${throttler.name} | ` +
`Limit: ${tierLimits.limit}/${Math.round(tierLimits.ttl / 1000)}s`,
);

this.monitorService?.recordViolation({
userId: user?.id || null,
ip: request.ip || 'unknown',
tier,
route: request.path,
method: request.method,
throttlerName: throttler.name,
limit: tierLimits.limit,
ttl: tierLimits.ttl,
timestamp: new Date(),
});

response.setHeader(
'Retry-After',
Math.ceil(tierLimits.ttl / 1000),
);
response.setHeader('X-RateLimit-Remaining', 0);
response.setHeader(
'X-RateLimit-Reset',
new Date(Date.now() + tierLimits.ttl).toISOString(),
);

throw new ThrottlerException(
`Rate limit exceeded for ${tier} tier. ` +
`Maximum ${tierLimits.limit} requests per ${Math.round(tierLimits.ttl / 1000)} seconds.`,
);
}
throw error;
}
}
}
94 changes: 94 additions & 0 deletions backend/src/common/services/rate-limit-monitor.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { EventEmitter2 } from '@nestjs/event-emitter';

export interface RateLimitViolation {
userId: string | null;
ip: string;
tier: string;
route: string;
method: string;
throttlerName: string;
limit: number;
ttl: number;
timestamp: Date;
}

/**
* In-memory rate limit violation tracker for admin monitoring.
* Stores the last 1000 violations in a circular buffer.
*/
@Injectable()
export class RateLimitMonitorService {
private readonly logger = new Logger(RateLimitMonitorService.name);
private readonly violations: RateLimitViolation[] = [];
private readonly MAX_VIOLATIONS = 1000;

constructor(private readonly eventEmitter: EventEmitter2) {}

recordViolation(violation: RateLimitViolation): void {
if (this.violations.length >= this.MAX_VIOLATIONS) {
this.violations.shift();
}
this.violations.push(violation);
this.eventEmitter.emit('ratelimit.violation', violation);
}

getRecentViolations(limit = 50): RateLimitViolation[] {
return this.violations.slice(-limit).reverse();
}

getViolationsByUser(userId: string, limit = 50): RateLimitViolation[] {
return this.violations
.filter((v) => v.userId === userId)
.slice(-limit)
.reverse();
}

getViolationSummary(): {
total: number;
last24h: number;
topOffenders: { userId: string; count: number }[];
byTier: Record<string, number>;
byRoute: Record<string, number>;
} {
const now = Date.now();
const dayAgo = now - 24 * 60 * 60 * 1000;

const last24h = this.violations.filter(
(v) => v.timestamp.getTime() >= dayAgo,
);

// Count by user
const userCounts: Record<string, number> = {};
for (const v of last24h) {
const key = v.userId || v.ip;
userCounts[key] = (userCounts[key] || 0) + 1;
}
const topOffenders = Object.entries(userCounts)
.map(([userId, count]) => ({ userId, count }))
.sort((a, b) => b.count - a.count)
.slice(0, 10);

// Count by tier
const byTier: Record<string, number> = {};
for (const v of last24h) {
byTier[v.tier] = (byTier[v.tier] || 0) + 1;
}

// Count by route
const byRoute: Record<string, number> = {};
for (const v of last24h) {
const key = `${v.method} ${v.route}`;
byRoute[key] = (byRoute[key] || 0) + 1;
}

return {
total: this.violations.length,
last24h: last24h.length,
topOffenders,
byTier,
byRoute,
};
}
}
Loading
Loading