diff --git a/build_manager/manager.py b/build_manager/manager.py index 315025f1..60bbc2ee 100644 --- a/build_manager/manager.py +++ b/build_manager/manager.py @@ -2,7 +2,6 @@ import redis import dill from enum import Enum -from utils import RateLimiter import logging import hashlib from metadata_manager import RemoteInfo @@ -136,14 +135,6 @@ def __init__(self, self.__task_queue = redis_task_queue_name self.__outdir = outdir - # Initialide an IP-based rate limiter. - # Allow 10 builds per hour per client - self.__ip_rate_limiter = RateLimiter( - redis_host=redis_host, - redis_port=redis_port, - time_window_sec=3600, - allowed_requests=10 - ) self.__build_entry_prefix = "buildmeta-" self.logger = logging.getLogger(__name__) self.logger.info( @@ -218,21 +209,17 @@ def __generate_build_id(self, build_info: BuildInfo) -> str: return bid def submit_build(self, - build_info: BuildInfo, - client_ip: str) -> str: + build_info: BuildInfo) -> str: """ Submit a new build request, generate a build ID, and queue the build for processing. Parameters: build_info (BuildInfo): The build information. - client_ip (str): The IP address of the client submitting the - build request. Returns: str: The generated build ID for the submitted build. """ - self.__ip_rate_limiter.count(client_ip) build_id = self.__generate_build_id(build_info) self.__insert_build_info(build_id=build_id, build_info=build_info) self.__queue_build(build_id=build_id) diff --git a/docker-compose.yml b/docker-compose.yml index 385df2cc..6e09c615 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,6 +22,7 @@ services: CBS_REMOTES_RELOAD_TOKEN: ${CBS_REMOTES_RELOAD_TOKEN} PYTHONPATH: /app CBS_BUILD_TIMEOUT_SEC: ${CBS_BUILD_TIMEOUT_SEC:-900} + FORWARDED_ALLOW_IPS: ${FORWARDED_ALLOW_IPS:-*} volumes: - ./base:/base:rw depends_on: diff --git a/utils/__init__.py b/utils/__init__.py index bdf51532..2bcc7e7e 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,8 +1,5 @@ from .taskrunner import TaskRunner -from .ratelimiter import RateLimiter, RateLimitExceededException __all__ = [ "TaskRunner", - "RateLimiter", - "RateLimitExceededException" ] diff --git a/utils/ratelimiter.py b/utils/ratelimiter.py deleted file mode 100644 index 19b492dc..00000000 --- a/utils/ratelimiter.py +++ /dev/null @@ -1,127 +0,0 @@ -import redis -import logging - - -class RateLimiter: - """ - A rate limiter that uses Redis as a backend to store request counts. - - This class allows you to limit the number of requests made by a client - (identified by a key) within a specified time window. - """ - def __init__(self, redis_host: str, redis_port: int, - time_window_sec: int, allowed_requests: int) -> None: - """ - Initialises the RateLimiter instance. - - Parameters: - redis_host (str): The Redis server hostname. - redis_port (int): The Redis server port. - time_window_sec (int): The time window (in seconds) in which - requests are counted. - allowed_requests (int): The maximum number of requests allowed - in the time window. - """ - self.__redis_client = redis.Redis( - host=redis_host, - port=redis_port, - decode_responses=True - ) - self.logger = logging.getLogger(__name__) - self.logger.info( - f"Redis connection established with {redis_host}:{redis_port}" - ) - - # Unique key prefix for this rate limiter instance - self.__key_prefix = f"rl-{id(self)}-" - self.__time_window_sec = time_window_sec - self.__allowed_requests = allowed_requests - - self.logger.info( - "RateLimiter initialized with parameters: " - f"Key prefix: {self.__key_prefix}, " - f"Time window: {self.__time_window_sec}s, " - f"Allowed requests per window: {self.__allowed_requests}" - ) - - def __del__(self) -> None: - """ - Clean up and close the Redis connection when the RateLimiter - is deleted. - """ - if self.__redis_client: - self.__redis_client.close() - self.logger.debug( - f"Redis connection closed for RateLimiter with id {id(self)}" - ) - - def __get_prefixed_key(self, key: str) -> str: - """ - Generates a unique key for Redis by adding a prefix. - - This helps avoid key collision in Redis with other data stored there. - - Parameters: - key (str): The key (e.g., client identifier) to be used for rate - limiting. - - Returns: - str: The Redis key with the instance-specific prefix. - """ - return self.__key_prefix + key - - def count(self, key: str) -> None: - """ - Increment the request count for a specific key (e.g., an IP address) - within the current time window. - - Parameters: - key (str): The key for which the request count is being updated. - For example, an IP address if rate limiting based on IPs. - - Raises: - RateLimitExceededException: If the number of requests exceeds the - allowed limit for the current time window. - """ - self.logger.debug(f"Counting a request for key: {key}") - pfx_key = self.__get_prefixed_key(key) - - # Check if the key already exists in Redis - if self.__redis_client.exists(pfx_key): - current_count = int(self.__redis_client.get(pfx_key)) - self.logger.debug( - f"Current request count for '{pfx_key}': {current_count}" - ) - - # If request count exceeds the allowed limit, raise exception - if current_count >= self.__allowed_requests: - self.logger.warning(f"Rate limit exceeded for key '{pfx_key}'") - raise RateLimitExceededException - - # Increment request count and keep TTL (time-to-live) unchanged - self.__redis_client.set( - name=pfx_key, - value=(current_count + 1), - keepttl=True - ) - else: - # Key doesn't exist yet, initialise count with TTL for time window - self.logger.debug( - f"No previous requests for key '{pfx_key}' in current window" - ", initialising count to 1" - ) - self.__redis_client.set( - name=pfx_key, - value=1, - ex=self.__time_window_sec - ) - - -class RateLimiterException(Exception): - pass - - -class RateLimitExceededException(RateLimiterException): - def __init__(self, *args): - message = "Too many requests. Try after some time." - super().__init__(message) diff --git a/web/api/v1/builds.py b/web/api/v1/builds.py index 107269bb..2e5d29bd 100644 --- a/web/api/v1/builds.py +++ b/web/api/v1/builds.py @@ -16,7 +16,7 @@ BuildOut, ) from services.builds import get_builds_service, BuildsService -from utils import RateLimitExceededException +from core.limiter import limiter router = APIRouter(prefix="/builds", tags=["builds"]) @@ -28,9 +28,19 @@ responses={ 400: {"description": "Invalid build configuration"}, 404: {"description": "Vehicle, board, or version not found"}, - 429: {"description": "Rate limit exceeded"} + 429: { + "description": "Rate limit exceeded", + "content": { + "application/json": { + "example": { + "detail": "Too many requests. Try again after some time." + } + } + } + } } ) +@limiter.limit("10/hour") async def create_build( build_request: BuildRequest, request: Request, @@ -52,19 +62,7 @@ async def create_build( 429: Rate limit exceeded """ try: - # Get client IP for rate limiting - forwarded_for = request.headers.get('X-Forwarded-For', None) - if forwarded_for: - client_ip = forwarded_for.split(',')[0].strip() - else: - client_ip = request.client.host if request.client else "unknown" - - return service.create_build(build_request, client_ip) - except RateLimitExceededException as e: - raise HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail=str(e) - ) + return service.create_build(build_request) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: diff --git a/web/core/limiter.py b/web/core/limiter.py new file mode 100644 index 00000000..9c4ab98b --- /dev/null +++ b/web/core/limiter.py @@ -0,0 +1,35 @@ +import logging +from fastapi import Request +from fastapi.responses import JSONResponse +from slowapi.errors import RateLimitExceeded +from slowapi import Limiter +from slowapi.util import get_remote_address +from core.config import get_settings + +logger = logging.getLogger(__name__) + +settings = get_settings() + +# We use the same redis instance which is used to store build metadata +# and other cached data. To keep that data separate, we use db-1 of the +# redis instance instead of the default db-0. +REDIS_DB_NUMBER = 1 +limiter = Limiter( + key_func=get_remote_address, + storage_uri=f"redis://{settings.redis_host}:{settings.redis_port}/{REDIS_DB_NUMBER}", + strategy="fixed-window", +) + + +def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: + """ + Response to send when a rate limit is exception is raised + """ + response = JSONResponse( + {"detail": "Too many requests. Try again after some time."}, + status_code=429 + ) + response = request.app.state.limiter._inject_headers( + response, request.state.view_rate_limit + ) + return response diff --git a/web/main.py b/web/main.py index 4f7adbf0..a87d716a 100755 --- a/web/main.py +++ b/web/main.py @@ -11,12 +11,16 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles +from slowapi.errors import RateLimitExceeded +from slowapi.middleware import SlowAPIMiddleware from api.v1 import router as v1_router from ui import router as ui_router + from core.config import get_settings from core.startup import initialize_application from core.logging_config import setup_logging +from core.limiter import limiter, rate_limit_exceeded_handler import ap_git import metadata_manager @@ -90,6 +94,7 @@ async def lifespan(app: FastAPI): app.state.build_manager = build_mgr app.state.inbuilt_builder = inbuilt_builder app.state.inbuilt_builder_thread = inbuilt_builder_thread + app.state.limiter = limiter yield @@ -114,6 +119,10 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) +# SlowAPIMiddleware is used for rate limiting +app.add_middleware(SlowAPIMiddleware) +app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler) + # Mount static files WEB_ROOT = Path(__file__).resolve().parent app.mount( diff --git a/web/requirements.txt b/web/requirements.txt index 789a92ed..1862bbda 100644 --- a/web/requirements.txt +++ b/web/requirements.txt @@ -8,3 +8,4 @@ dill==0.3.8 packaging==25.0 jinja2==3.1.2 python-multipart==0.0.6 +slowapi==0.1.9 diff --git a/web/services/builds.py b/web/services/builds.py index 619d7733..556bfe62 100644 --- a/web/services/builds.py +++ b/web/services/builds.py @@ -42,15 +42,13 @@ def __init__( def create_build( self, - build_request: BuildRequest, - client_ip: str + build_request: BuildRequest ) -> BuildSubmitResponse: """ Create a new build request. Args: build_request: Build configuration - client_ip: Client IP address for rate limiting Returns: Simple response with build_id and URL @@ -149,8 +147,7 @@ def create_build( # Submit build build_id = self.manager.submit_build( - build_info=build_info, - client_ip=client_ip, + build_info=build_info ) # Return simple submission response