Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions fasthttp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,15 @@ async def lifespan(app: FastHTTP):
self.secret_key = secret_key or secrets.token_bytes(32)
self.security = Security(secret_key=self.secret_key) if security else None
self.startup_uuid = None
if self.generate_startup_uuid and self.startup_uuid_version == "v4":
self.startup_uuid = str(uuid.uuid4())
if self.generate_startup_uuid:
if self.startup_uuid_version == "v7":
import sys
if sys.version_info >= (3, 13):
self.startup_uuid = str(uuid.uuid7())
else:
self.startup_uuid = str(uuid.uuid4())
else:
self.startup_uuid = str(uuid.uuid4())

self.client = HTTPClient(
self.request_configs,
Expand Down Expand Up @@ -875,7 +882,10 @@ def _log_result(
)

if route.response_model:
json_data = result.json()
try:
json_data = result.json()
except Exception:
json_data = None
if json_data is not None:
if get_origin(route.response_model) is list:
item_model = get_args(route.response_model)[0]
Expand Down Expand Up @@ -1223,11 +1233,8 @@ def _find_route(self, url: str, method: str) -> Route | None:
normalized_url = self._normalize_url(url)
for route in self.fasthttp.routes:
route_normalized = self._normalize_url(route.url)
if route.method.upper() == method.upper():
if route_normalized == normalized_url:
return route
if route_normalized in normalized_url or normalized_url in route_normalized:
return route
if route.method.upper() == method.upper() and route_normalized == normalized_url:
return route
return None

async def _handle_proxy(
Expand Down Expand Up @@ -1300,7 +1307,7 @@ async def _handle_proxy(
else:
result["json"] = json_data
except Exception as e:
print(f"DEBUG: validation error={e}")
self.fasthttp.logger.debug("validation error=%s", e)

await self._send_json(send, result)

Expand Down
1 change: 1 addition & 0 deletions fasthttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _validate_request(self, route: Route) -> bool:
return False

async def _prepare_config(self, route: Route, config: dict) -> dict:
config = dict(config)
headers = dict(config.get("headers") or {})
headers.setdefault("User-Agent", f"fasthttp/{__version__}")

Expand Down
8 changes: 3 additions & 5 deletions fasthttp/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ def __init__(
self.cache_methods = cache_methods or ["GET"]
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._lock = asyncio.Lock()
self._cached_response: Response | None = None

def _generate_key(self, route: "Route") -> str:
key_data = f"{route.method}:{route.url}:{json.dumps(route.params or {}, sort_keys=True)}"
Expand Down Expand Up @@ -551,7 +550,7 @@ async def before_request(

if time.time() < entry.expires_at:
self._cache.move_to_end(key)
self._cached_response = entry.response
config["_cache_hit"] = entry.response
return config
del self._cache[key]

Expand All @@ -578,9 +577,8 @@ async def after_response(
if route.method not in self.cache_methods:
return response

if self._cached_response is not None:
cached = self._cached_response
self._cached_response = None
cached = config.get("_cache_hit")
if cached is not None:
return cached

key = self._generate_key(route)
Expand Down
2 changes: 1 addition & 1 deletion fasthttp/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def path_params(
]:
return {}

def json(self) -> dict[str, Any]:
def json(self) -> Any:
"""
Parse the response body as JSON.

Expand Down
12 changes: 6 additions & 6 deletions fasthttp/security/limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ def __init__(
config: LimitsConfig | None = None
) -> None:
self._config = config or LimitsConfig()
self._semaphore: asyncio.Semaphore | None = None
self._semaphore: asyncio.Semaphore | None = (
asyncio.Semaphore(self._config.max_concurrent_requests)
if self._config.max_concurrent_requests > 0
else None
)
self._last_request_time = 0.0
self._cooldown_lock = asyncio.Lock()

Expand Down Expand Up @@ -47,11 +51,7 @@ def validate_url_length(self, url: str) -> bool:
return len(url) <= self._config.max_url_length

async def acquire(self) -> None:
if self._config.max_concurrent_requests > 0:
if not self._semaphore:
self._semaphore = asyncio.Semaphore(
self._config.max_concurrent_requests
)
if self._semaphore:
await self._semaphore.acquire()

def release(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion fasthttp/security/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def validate_response(
text = content.decode("utf-8", errors="ignore")
xss_check = self.detect_xss(text)
if xss_check[0]:
return xss_check
return False, xss_check[1]
except Exception:
pass

Expand Down
2 changes: 1 addition & 1 deletion fasthttp/security/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def secret_key(self) -> bytes:
return self._signer._secret_key

async def pre_request(self, url: str, method: str) -> None:
self._ssrf.validate_request(url)
await self._ssrf.validate_request(url)

if not self._limits.validate_url_length(url):
raise SecurityError(f"URL too long: {len(url)} chars")
Expand Down
9 changes: 5 additions & 4 deletions fasthttp/security/ssrf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import ipaddress
import socket
from urllib.parse import urlparse
Expand Down Expand Up @@ -41,7 +42,7 @@ class SSRFProtection:
def __init__(self):
self._dns_cache = {}

def check_url(self, url: str) -> bool:
async def check_url(self, url: str) -> bool:
parsed = urlparse(url)
hostname = parsed.hostname

Expand All @@ -60,7 +61,7 @@ def check_url(self, url: str) -> bool:
return False

try:
ip_str = socket.gethostbyname(hostname)
ip_str = await asyncio.to_thread(socket.gethostbyname, hostname)
ip = ipaddress.ip_address(ip_str)

if ip.is_loopback:
Expand All @@ -83,8 +84,8 @@ def check_url(self, url: str) -> bool:

return True

def validate_request(self, url: str) -> None:
if not self.check_url(url):
async def validate_request(self, url: str) -> None:
if not await self.check_url(url):
raise SSRFBlockedError(f"SSRF protection blocked request to: {url}")


Expand Down
Loading