|
| 1 | +"""Rate limiting plugin using a simple token bucket algorithm. |
| 2 | +
|
| 3 | +This example demonstrates stateful request processing and rate limiting. |
| 4 | +""" |
| 5 | + |
| 6 | +import asyncio |
| 7 | +import logging |
| 8 | +import time |
| 9 | +from collections import defaultdict |
| 10 | + |
| 11 | +from google.protobuf.empty_pb2 import Empty |
| 12 | +from grpc import ServicerContext |
| 13 | + |
| 14 | +from mcpd_plugins import BasePlugin, serve |
| 15 | +from mcpd_plugins.v1.plugins.plugin_pb2 import ( |
| 16 | + FLOW_REQUEST, |
| 17 | + Capabilities, |
| 18 | + HTTPRequest, |
| 19 | + HTTPResponse, |
| 20 | + Metadata, |
| 21 | +) |
| 22 | + |
# Configure the root logger at INFO level. This runs at import time, so
# importing this module also changes process-wide logging configuration.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the plugin below.
logger = logging.getLogger(__name__)
| 25 | + |
| 26 | + |
class RateLimitPlugin(BasePlugin):
    """Plugin that implements rate limiting using a token bucket algorithm.

    Every client, keyed by ``request.remote_addr``, owns a bucket that holds at
    most ``requests_per_minute`` tokens and is refilled continuously at
    ``requests_per_minute / 60`` tokens per second. Each allowed request
    consumes one token; a request arriving when less than one token is
    available is answered with a 429 response.

    NOTE: ``buckets`` and ``last_update`` are never pruned, so memory grows
    with the number of distinct client addresses seen.
    """

    def __init__(self, requests_per_minute: int = 60):
        """Initialize the rate limiter.

        Args:
            requests_per_minute: Maximum requests allowed per minute per client.
                Also used as the bucket capacity (maximum burst size).

        Raises:
            ValueError: If ``requests_per_minute`` is not positive. A
                non-positive rate would make the Retry-After computation
                divide by zero (or produce a negative delay).
        """
        if requests_per_minute <= 0:
            raise ValueError(
                f"requests_per_minute must be positive, got {requests_per_minute!r}"
            )

        super().__init__()
        self.requests_per_minute = requests_per_minute
        self.rate_per_second = requests_per_minute / 60.0

        # Track tokens for each client IP. A new client starts with a full bucket.
        self.buckets: dict[str, float] = defaultdict(lambda: float(requests_per_minute))
        # Timestamp of each client's last refill. These are time.monotonic()
        # readings (not wall-clock epoch seconds) so that system clock
        # adjustments can neither drain nor refill a bucket.
        self.last_update: dict[str, float] = defaultdict(time.monotonic)

    async def GetMetadata(self, request: Empty, context: ServicerContext) -> Metadata:
        """Return plugin metadata (name, version and a human-readable description)."""
        return Metadata(
            name="rate-limit-plugin",
            version="1.0.0",
            description=f"Rate limits requests to {self.requests_per_minute} per minute per client",
        )

    async def GetCapabilities(self, request: Empty, context: ServicerContext) -> Capabilities:
        """Declare support for the request flow only."""
        return Capabilities(flows=[FLOW_REQUEST])

    async def HandleRequest(self, request: HTTPRequest, context: ServicerContext) -> HTTPResponse:
        """Apply rate limiting based on client IP.

        Args:
            request: Incoming HTTP request; ``remote_addr`` selects the bucket.
            context: gRPC servicer context (unused).

        Returns:
            A response with ``continue=True`` carrying a copy of the request and
            the ``X-RateLimit-*`` headers when a token is available, otherwise
            the 429 response built by ``_rate_limit_response``.
        """
        # All requests without a remote address share the single "unknown" bucket.
        client_ip = request.remote_addr or "unknown"
        logger.info("Rate limit check for %s: %s %s", client_ip, request.method, request.url)

        # Refill tokens based on time elapsed, capped at the bucket capacity.
        # This method contains no await, so the read-modify-write sequence on
        # the bucket is not interleaved with other coroutines on the same loop.
        now = time.monotonic()
        elapsed = now - self.last_update[client_ip]
        self.buckets[client_ip] = min(
            float(self.requests_per_minute),
            self.buckets[client_ip] + elapsed * self.rate_per_second,
        )
        self.last_update[client_ip] = now

        # Check if client has tokens available.
        if self.buckets[client_ip] < 1.0:
            logger.warning("Rate limit exceeded for %s", client_ip)
            return self._rate_limit_response(client_ip)

        # Consume one token.
        self.buckets[client_ip] -= 1.0
        logger.info(
            "Request allowed for %s (tokens remaining: %.2f)",
            client_ip,
            self.buckets[client_ip],
        )

        # Add rate limit headers to response. "continue" is a Python keyword,
        # so the field has to be passed through dict unpacking.
        response = HTTPResponse(**{"continue": True})
        response.modified_request.CopyFrom(request)
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
        response.headers["X-RateLimit-Remaining"] = str(int(self.buckets[client_ip]))

        return response

    def _rate_limit_response(self, client_ip: str) -> HTTPResponse:
        """Create a 429 Too Many Requests response.

        Args:
            client_ip: Bucket key of the throttled client.

        Returns:
            A response with ``continue=False``, status 429, a JSON error body
            and ``X-RateLimit-*`` / ``Retry-After`` headers.
        """
        # Calculate retry-after in seconds: time until one full token is
        # available, truncated and bumped by one so the client never retries
        # too early. rate_per_second is guaranteed positive by __init__.
        tokens_needed = 1.0 - self.buckets[client_ip]
        retry_after = int(tokens_needed / self.rate_per_second) + 1

        response = HTTPResponse(
            **{"continue": False},
            status_code=429,
            body=b'{"error": "Rate limit exceeded"}',
        )
        response.headers["Content-Type"] = "application/json"
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
        response.headers["X-RateLimit-Remaining"] = "0"
        response.headers["Retry-After"] = str(retry_after)

        return response
| 104 | + |
| 105 | + |
if __name__ == "__main__":
    # Script entry point: build the plugin with a limit of 60 requests per
    # minute per client and hand it to the mcpd plugin server.
    plugin = RateLimitPlugin(requests_per_minute=60)
    asyncio.run(serve(plugin))
0 commit comments