1010from collections import defaultdict
1111
1212from google .protobuf .empty_pb2 import Empty
13- from grpc import ServicerContext
13+ from grpc . aio import ServicerContext
1414
1515from mcpd_plugins import BasePlugin , serve
1616from mcpd_plugins .v1 .plugins .plugin_pb2 import (
@@ -41,6 +41,7 @@ def __init__(self, requests_per_minute: int = 60):
4141 # Track tokens for each client IP.
4242 self .buckets : dict [str , float ] = defaultdict (lambda : float (requests_per_minute ))
4343 self .last_update : dict [str , float ] = defaultdict (time .time )
44+ self .locks : dict [str , asyncio .Lock ] = defaultdict (asyncio .Lock )
4445
4546 async def GetMetadata (self , request : Empty , context : ServicerContext ) -> Metadata :
4647 """Return plugin metadata."""
@@ -57,33 +58,34 @@ async def GetCapabilities(self, request: Empty, context: ServicerContext) -> Cap
5758 async def HandleRequest (self , request : HTTPRequest , context : ServicerContext ) -> HTTPResponse :
5859 """Apply rate limiting based on client IP."""
5960 client_ip = request .remote_addr or "unknown"
60- logger .info (f"Rate limit check for { client_ip } : { request .method } { request .url } " )
61-
62- # Refill tokens based on time elapsed.
63- now = time .time ()
64- elapsed = now - self .last_update [client_ip ]
65- self .buckets [client_ip ] = min (
66- self .requests_per_minute ,
67- self .buckets [client_ip ] + elapsed * self .rate_per_second ,
68- )
69- self .last_update [client_ip ] = now
70-
71- # Check if client has tokens available.
72- if self .buckets [client_ip ] < 1.0 :
73- logger .warning (f"Rate limit exceeded for { client_ip } " )
74- return self ._rate_limit_response (client_ip )
75-
76- # Consume one token.
77- self .buckets [client_ip ] -= 1.0
78- logger .info (f"Request allowed for { client_ip } (tokens remaining: { self .buckets [client_ip ]:.2f} )" )
79-
80- # Add rate limit headers to response.
81- response = HTTPResponse (** {"continue" : True })
82- response .modified_request .CopyFrom (request )
83- response .headers ["X-RateLimit-Limit" ] = str (self .requests_per_minute )
84- response .headers ["X-RateLimit-Remaining" ] = str (int (self .buckets [client_ip ]))
85-
86- return response
61+ logger .info ("Rate limit check for %s: %s %s" , client_ip , request .method , request .url )
62+
63+ async with self .locks [client_ip ]:
64+ # Refill tokens based on time elapsed.
65+ now = time .time ()
66+ elapsed = now - self .last_update [client_ip ]
67+ self .buckets [client_ip ] = min (
68+ self .requests_per_minute ,
69+ self .buckets [client_ip ] + elapsed * self .rate_per_second ,
70+ )
71+ self .last_update [client_ip ] = now
72+
73+ # Check if client has tokens available.
74+ if self .buckets [client_ip ] < 1.0 :
75+ logger .warning ("Rate limit exceeded for %s" , client_ip )
76+ return self ._rate_limit_response (client_ip )
77+
78+ # Consume one token.
79+ self .buckets [client_ip ] -= 1.0
80+ logger .info ("Request allowed for %s (tokens remaining: %.2f)" , client_ip , self .buckets [client_ip ])
81+
82+ # Add rate limit headers to response.
83+ response = HTTPResponse (** {"continue" : True })
84+ response .modified_request .CopyFrom (request )
85+ response .headers ["X-RateLimit-Limit" ] = str (self .requests_per_minute )
86+ response .headers ["X-RateLimit-Remaining" ] = str (int (self .buckets [client_ip ]))
87+
88+ return response
8789
8890 def _rate_limit_response (self , client_ip : str ) -> HTTPResponse :
8991 """Create a 429 Too Many Requests response."""
0 commit comments