From 2dc3f9b82ebd86e08849f7c56183ddab35c58806 Mon Sep 17 00:00:00 2001 From: Michael Paul Date: Wed, 2 Apr 2025 21:03:12 -0400 Subject: [PATCH 1/6] Reimplement request body's __await__ based on __anext__ This decouples __await__ from the details of how the data is received from the client, by delegating those details to __anext__ instead. --- src/quart/wrappers/request.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/quart/wrappers/request.py b/src/quart/wrappers/request.py index 1cf2b2f..35b35cd 100644 --- a/src/quart/wrappers/request.py +++ b/src/quart/wrappers/request.py @@ -88,17 +88,17 @@ async def __anext__(self) -> bytes: return data def __await__(self) -> Generator[Any, None, Any]: - # Must check the _must_raise before and after waiting on the - # completion event as it may change whilst waiting and the - # event may not be set if there is already an issue. - if self._must_raise is not None: - raise self._must_raise + async def accumulate_data() -> bytes: + data = bytearray() - yield from self._complete.wait().__await__() + # Receive chunks of data from the client and build up the complete + # request body. + async for data_chunk in self: + data.extend(data_chunk) - if self._must_raise is not None: - raise self._must_raise - return bytes(self._data) + return bytes(data) + + return accumulate_data().__await__() def append(self, data: bytes) -> None: if data == b"" or self._must_raise is not None: From 3465df77810678da1d5c67093366472ebe9ad993 Mon Sep 17 00:00:00 2001 From: Michael Paul Date: Wed, 2 Apr 2025 21:29:20 -0400 Subject: [PATCH 2/6] Remove docstring entry for parameter that doesn't exist Request.__init__ doesn't take an awaitable body object as a parameter; it creates that object by itself. --- src/quart/wrappers/request.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/quart/wrappers/request.py b/src/quart/wrappers/request.py index 35b35cd..6cfa1fd 100644 --- a/src/quart/wrappers/request.py +++ b/src/quart/wrappers/request.py @@ -171,8 +171,6 @@ def __init__( root_path: The root path that should be prepended to all routes. http_version: The HTTP version of the request. - body: An awaitable future for the body data i.e. - ``data = await body`` max_content_length: The maximum length in bytes of the body (None implies no limit in Quart). body_timeout: The maximum time (seconds) to wait for the From 9bbe9ceeb95c8f6ac98187dd0df651778935b158 Mon Sep 17 00:00:00 2001 From: Michael Paul Date: Sat, 5 Apr 2025 10:37:24 -0400 Subject: [PATCH 3/6] Implement backpressure for HTTP request body Instead of ASGIHTTPConnection's receiver task appending the data to a byte array, it now puts chunks into a queue that the request body reads from. The queue is limited to a single item, so the receiver task can only put chunks into the queue as quickly as the handler task is taking them out. This limits the rate at which body chunks are obtained from the ASGI receive function, and if the ASGI server also supports pressure, this will limit the rate at which input is actually received from the client, so that large requests won't be buffered in memory if the client is able to upload faster than the application can consume the data. The Body class no longer has methods to append data or indicate completion, because it no longer owns the storage of the data; instead, those operations are now provided by the AsyncQueueIterator class, which provides an iterator that the body reads from. (For tests, there's also a make_test_body_chunks function that produces an iterable with a predefined sequence of chunks.) Because the body object only streams the data now instead of storing it, the request's get_data() method no longer expects the body object to provide the data more than once; instead, get_data() now stores the data on in the request object if caching is requested. --- src/quart/app.py | 3 +- src/quart/asgi.py | 15 ++++--- src/quart/testing/__init__.py | 2 + src/quart/testing/utils.py | 6 +++ src/quart/utils.py | 63 +++++++++++++++++++++++++++ src/quart/wrappers/request.py | 77 ++++++++++++++------------------- tests/test_app.py | 3 ++ tests/test_asgi.py | 2 +- tests/test_ctx.py | 5 +++ tests/test_formparser.py | 5 ++- tests/test_routing.py | 2 + tests/test_sessions.py | 2 + tests/wrappers/test_request.py | 58 ++++++++++++++++++------- tests/wrappers/test_response.py | 6 +++ 14 files changed, 178 insertions(+), 71 deletions(-) diff --git a/src/quart/app.py b/src/quart/app.py index a25933c..21e5a43 100644 --- a/src/quart/app.py +++ b/src/quart/app.py @@ -84,6 +84,7 @@ from .signals import websocket_tearing_down from .templating import _default_template_ctx_processor from .templating import Environment +from .testing import make_test_body_chunks from .testing import make_test_body_with_headers from .testing import make_test_headers_path_and_query_string from .testing import make_test_scope @@ -1363,10 +1364,10 @@ def test_request_context( headers, root_path, http_version, + body_chunks=make_test_body_chunks(request_body), send_push_promise=send_push_promise, scope=scope, ) - request.body.set_result(request_body) return self.request_context(request) def add_background_task(self, func: Callable, *args: Any, **kwargs: Any) -> None: diff --git a/src/quart/asgi.py b/src/quart/asgi.py index 7221d48..9ac995b 100644 --- a/src/quart/asgi.py +++ b/src/quart/asgi.py @@ -31,6 +31,7 @@ from .signals import websocket_received from .signals import websocket_sent from .typing import ResponseTypes +from .utils import AsyncQueueIterator from .utils import cancel_tasks from .utils import encode_headers from .utils import raise_task_exceptions @@ -46,12 +47,13 @@ class ASGIHTTPConnection: def __init__(self, app: Quart, scope: HTTPScope) -> None: self.app = app self.scope = scope + self.queue: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) async def __call__( self, receive: ASGIReceiveCallable, send: ASGISendCallable ) -> None: request = self._create_request_from_scope(send) - receiver_task = asyncio.ensure_future(self.handle_messages(request, receive)) + receiver_task = asyncio.ensure_future(self.handle_messages(receive)) handler_task = asyncio.ensure_future(self.handle_request(request, send)) done, pending = await asyncio.wait( [handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED @@ -59,15 +61,15 @@ async def __call__( await cancel_tasks(pending) raise_task_exceptions(done) - async def handle_messages( - self, request: Request, receive: ASGIReceiveCallable - ) -> None: + async def handle_messages(self, receive: ASGIReceiveCallable) -> None: + queue = self.queue # for quicker access in the loop + while True: message = await receive() if message["type"] == "http.request": - request.body.append(message.get("body", b"")) + await queue.put(message.get("body", b"")) if not message.get("more_body", False): - request.body.set_complete() + queue.set_complete() elif message["type"] == "http.disconnect": return @@ -99,6 +101,7 @@ def _create_request_from_scope(self, send: ASGISendCallable) -> Request: self.scope["http_version"], max_content_length=self.app.config["MAX_CONTENT_LENGTH"], body_timeout=self.app.config["BODY_TIMEOUT"], + body_chunks=self.queue, send_push_promise=partial(self._send_push_promise, send), scope=self.scope, ) diff --git a/src/quart/testing/__init__.py b/src/quart/testing/__init__.py index 9b280bd..b49acc2 100644 --- a/src/quart/testing/__init__.py +++ b/src/quart/testing/__init__.py @@ -9,6 +9,7 @@ from .app import TestApp from .client import QuartClient from .connections import WebsocketResponseError +from .utils import make_test_body_chunks from .utils import make_test_body_with_headers from .utils import make_test_headers_path_and_query_string from .utils import make_test_scope @@ -35,6 +36,7 @@ def invoke(self, cli: Any = None, args: Any = None, **kwargs: Any) -> Any: # ty __all__ = ( + "make_test_body_chunks", "make_test_body_with_headers", "make_test_headers_path_and_query_string", "make_test_scope", diff --git a/src/quart/testing/utils.py b/src/quart/testing/utils.py index a39e3ef..071d8e0 100644 --- a/src/quart/testing/utils.py +++ b/src/quart/testing/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import AsyncIterator from typing import Any from typing import AnyStr from typing import cast @@ -218,6 +219,11 @@ def make_test_scope( return cast(Scope, scope) +async def make_test_body_chunks(*chunks: bytes) -> AsyncIterator[bytes]: + for chunk in chunks: + yield chunk + + async def no_op_push(path: str, headers: Headers) -> None: """A push promise sender that does nothing. diff --git a/src/quart/utils.py b/src/quart/utils.py index 3669295..7a2f1ce 100644 --- a/src/quart/utils.py +++ b/src/quart/utils.py @@ -24,6 +24,11 @@ from .typing import Event from .typing import FilePath +if sys.version_info >= (3, 10): + from typing import Self +else: + from typing_extensions import Self + if TYPE_CHECKING: from .wrappers.response import Response # noqa: F401 @@ -184,3 +189,61 @@ def raise_task_exceptions(tasks: set[asyncio.Task]) -> None: for task in tasks: if not task.cancelled() and task.exception() is not None: raise task.exception() + + +# Dummy type used in AsyncQueueIterator to wakeup an await without sending any +# data. (None isn't used for that, because the generic type T could allow None +# as valid data in the queue.) +class _AsyncQueueWakeup: + pass + + +# Items go in using an async queue interface, and come out via async iteration. +class AsyncQueueIterator(AsyncIterator[T]): + _queue: asyncio.Queue[T | _AsyncQueueWakeup] + _complete: bool + + def __init__(self, maxsize: int = 0) -> None: + self._queue = asyncio.Queue(maxsize) + self._complete = False # In Python 3.13, use queue's shutdown() instead + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> T: + while not (self._queue.empty() and self._complete): + item = await self._queue.get() + + if not isinstance(item, _AsyncQueueWakeup): + return item + + raise StopAsyncIteration() + + def empty(self) -> bool: + return self._queue.empty() + + def full(self) -> bool: + return self._queue.full() + + def complete(self) -> bool: + return self._complete + + def _reject_if_complete(self) -> None: + if self._complete: + raise RuntimeError("already complete") + + async def put(self, item: T) -> None: + self._reject_if_complete() + + await self._queue.put(item) + + def put_nowait(self, item: T) -> None: + self._reject_if_complete() + + self._queue.put_nowait(item) + + def set_complete(self) -> None: + self._complete = True + + if self._queue.empty(): # so a get() might be waiting + self._queue.put_nowait(_AsyncQueueWakeup()) diff --git a/src/quart/wrappers/request.py b/src/quart/wrappers/request.py index 6cfa1fd..a093e19 100644 --- a/src/quart/wrappers/request.py +++ b/src/quart/wrappers/request.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from collections.abc import AsyncIterator from collections.abc import Awaitable from collections.abc import Generator from typing import Any @@ -49,11 +50,13 @@ class Body: """ def __init__( - self, expected_content_length: int | None, max_content_length: int | None + self, + chunks: AsyncIterator[bytes], + expected_content_length: int | None, + max_content_length: int | None, ) -> None: - self._data = bytearray() - self._complete: asyncio.Event = asyncio.Event() - self._has_data: asyncio.Event = asyncio.Event() + self._chunks = chunks + self._received_content_length = 0 self._max_content_length = max_content_length # Exceptions must be raised within application (not ASGI) # calls, this is achieved by having the ASGI methods set this @@ -73,18 +76,16 @@ async def __anext__(self) -> bytes: if self._must_raise is not None: raise self._must_raise - # if we got all of the data in the first shot, then self._complete is - # set and self._has_data will not get set again, so skip the await - # if we already have completed everything - if not self._complete.is_set(): - await self._has_data.wait() + data = await self._chunks.__anext__() + + self._received_content_length += len(data) - if self._complete.is_set() and len(self._data) == 0: - raise StopAsyncIteration() + if ( + self._max_content_length is not None + and self._received_content_length > self._max_content_length + ): + raise RequestEntityTooLarge() - data = bytes(self._data) - self._data.clear() - self._has_data.clear() return data def __await__(self) -> Generator[Any, None, Any]: @@ -100,30 +101,6 @@ async def accumulate_data() -> bytes: return accumulate_data().__await__() - def append(self, data: bytes) -> None: - if data == b"" or self._must_raise is not None: - return - self._data.extend(data) - self._has_data.set() - if ( - self._max_content_length is not None - and len(self._data) > self._max_content_length - ): - self._must_raise = RequestEntityTooLarge() - self.set_complete() - - def set_complete(self) -> None: - self._complete.set() - self._has_data.set() - - def set_result(self, data: bytes) -> None: - """Convenience method, mainly for testing.""" - self.append(data) - self.set_complete() - - def clear(self) -> None: - self._data.clear() - class Request(BaseRequestWebsocket): """This class represents a request. @@ -158,6 +135,7 @@ def __init__( *, max_content_length: int | None = None, body_timeout: int | None = None, + body_chunks: AsyncIterator[bytes], send_push_promise: Callable[[str, Headers], Awaitable[None]], ) -> None: """Create a request object. @@ -173,6 +151,8 @@ def __init__( http_version: The HTTP version of the request. max_content_length: The maximum length in bytes of the body (None implies no limit in Quart). + body_chunks: An async iterable that provides the request body as a + sequence of data chunks. body_timeout: The maximum time (seconds) to wait for the body before timing out. send_push_promise: An awaitable to send a push promise based @@ -183,7 +163,12 @@ def __init__( method, scheme, path, query_string, headers, root_path, http_version, scope ) self.body_timeout = body_timeout - self.body = self.body_class(self.content_length, max_content_length) + self.body = self.body_class( + body_chunks, + self.content_length, + max_content_length, + ) + self._cached_data: str | bytes | None = None self._cached_json: dict[bool, Any] = {False: Ellipsis, True: Ellipsis} self._form: MultiDict | None = None self._files: MultiDict | None = None @@ -269,6 +254,9 @@ async def get_data( parse_form_data: Parse the data as form data first, return any remaining data. """ + if self._cached_data is not None: + return self._cached_data + if parse_form_data: await self._load_form_data() @@ -277,13 +265,12 @@ async def get_data( except asyncio.TimeoutError as e: raise RequestTimeout() from e else: - if not cache: - self.body.clear() + data = raw_data.decode() if as_text else raw_data - if as_text: - return raw_data.decode() - else: - return raw_data + if cache: + self._cached_data = data + + return data @property async def values(self) -> CombinedMultiDict: diff --git a/tests/test_app.py b/tests/test_app.py index f5f2dcb..9ff99d4 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -17,6 +17,7 @@ from quart.globals import websocket from quart.sessions import SecureCookieSession from quart.sessions import SessionInterface +from quart.testing import make_test_body_chunks from quart.testing import no_op_push from quart.testing import WebsocketResponseError from quart.typing import ResponseReturnValue @@ -273,6 +274,7 @@ async def index() -> NoReturn: "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) with pytest.raises(asyncio.CancelledError): @@ -390,6 +392,7 @@ async def exception() -> ResponseReturnValue: "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) ) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index a491fa0..8bc7df2 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -118,7 +118,7 @@ async def receive() -> ASGIReceiveEvent: # This test fails with a timeout error if the request body is not received # within 1 second - receiver_task = asyncio.ensure_future(connection.handle_messages(request, receive)) + receiver_task = asyncio.ensure_future(connection.handle_messages(receive)) body = await asyncio.wait_for(request.body, timeout=1) receiver_task.cancel() diff --git a/tests/test_ctx.py b/tests/test_ctx.py index 7450b31..d117762 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -22,6 +22,7 @@ from quart.globals import request from quart.globals import websocket from quart.routing import QuartRule +from quart.testing import make_test_body_chunks from quart.testing import make_test_headers_path_and_query_string from quart.testing import no_op_push from quart.wrappers import Request @@ -42,6 +43,7 @@ async def test_request_context_match(http_scope: HTTPScope) -> None: "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) async with RequestContext(app, request): @@ -63,6 +65,7 @@ async def test_bad_request_if_websocket_route(http_scope: HTTPScope) -> None: "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) async with RequestContext(app, request): @@ -83,6 +86,7 @@ async def test_after_this_request(http_scope: HTTPScope) -> None: "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ), ) as context: @@ -102,6 +106,7 @@ async def test_has_request_context(http_scope: HTTPScope) -> None: "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) async with RequestContext(Quart(__name__), request): diff --git a/tests/test_formparser.py b/tests/test_formparser.py index c5e85f2..102b47b 100644 --- a/tests/test_formparser.py +++ b/tests/test_formparser.py @@ -4,6 +4,7 @@ from werkzeug.exceptions import RequestEntityTooLarge from quart.formparser import MultiPartParser +from quart.testing import make_test_body_chunks from quart.wrappers.request import Body @@ -11,8 +12,8 @@ async def test_multipart_max_form_memory_size() -> None: """max_form_memory_size is tracked across multiple data events.""" data = b"--bound\r\nContent-Disposition: form-field; name=a\r\n\r\n" data += b"a" * 15 + b"\r\n--bound--" - body = Body(None, None) - body.set_result(data) + body_chunks = make_test_body_chunks(data) + body = Body(body_chunks, None, None) # The buffer size is less than the max size, so multiple data events will be # returned. The field size is greater than the max. parser = MultiPartParser(max_form_memory_size=10, buffer_size=5) diff --git a/tests/test_routing.py b/tests/test_routing.py index 9756599..6f2b264 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -7,6 +7,7 @@ from werkzeug.datastructures import Headers from quart.routing import QuartMap +from quart.testing import make_test_body_chunks from quart.testing import no_op_push from quart.wrappers.request import Request @@ -28,6 +29,7 @@ async def test_bind_warning( "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 130b0e6..edb9afb 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -8,6 +8,7 @@ from quart.app import Quart from quart.sessions import SecureCookieSession from quart.sessions import SecureCookieSessionInterface +from quart.testing import make_test_body_chunks from quart.testing import no_op_push from quart.wrappers import Request from quart.wrappers import Response @@ -32,6 +33,7 @@ async def test_secure_cookie_session_interface_open_session( "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) request.headers["Cookie"] = response.headers["Set-Cookie"] diff --git a/tests/wrappers/test_request.py b/tests/wrappers/test_request.py index e2f0267..d711717 100644 --- a/tests/wrappers/test_request.py +++ b/tests/wrappers/test_request.py @@ -9,31 +9,37 @@ from werkzeug.exceptions import RequestEntityTooLarge from werkzeug.exceptions import RequestTimeout +from quart.testing import make_test_body_chunks from quart.testing import no_op_push +from quart.utils import AsyncQueueIterator from quart.wrappers.request import Body from quart.wrappers.request import Request -async def _fill_body(body: Body, semaphore: asyncio.Semaphore, limit: int) -> None: +async def _fill_body( + body_chunks: AsyncQueueIterator[bytes], semaphore: asyncio.Semaphore, limit: int +) -> None: for number in range(limit): - body.append(b"%d" % number) + await body_chunks.put(b"%d" % number) await semaphore.acquire() - body.set_complete() + body_chunks.set_complete() async def test_full_body() -> None: - body = Body(None, None) + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) + body = Body(body_chunks, None, None) limit = 3 semaphore = asyncio.Semaphore(limit) - asyncio.ensure_future(_fill_body(body, semaphore, limit)) + asyncio.ensure_future(_fill_body(body_chunks, semaphore, limit)) assert b"012" == await body async def test_body_streaming() -> None: - body = Body(None, None) + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) + body = Body(body_chunks, None, None) limit = 3 semaphore = asyncio.Semaphore(0) - asyncio.ensure_future(_fill_body(body, semaphore, limit)) + asyncio.ensure_future(_fill_body(body_chunks, semaphore, limit)) index = 0 async for data in body: semaphore.release() @@ -42,10 +48,22 @@ async def test_body_streaming() -> None: assert b"" == await body +async def test_body_streaming_backpressure() -> None: + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) + body = Body(body_chunks, None, None) + limit = 3 + semaphore = asyncio.Semaphore(2) # will be locked if more than 1 chunk queued + asyncio.ensure_future(_fill_body(body_chunks, semaphore, limit)) + async for _ in body: + assert not semaphore.locked() # only 1 chunk was accepted from source + semaphore.release() + + async def test_body_stream_single_chunk() -> None: - body = Body(None, None) - body.append(b"data") - body.set_complete() + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) + body = Body(body_chunks, None, None) + body_chunks.put_nowait(b"data") + body_chunks.set_complete() async def _check_data() -> None: async for data in body: @@ -55,9 +73,10 @@ async def _check_data() -> None: async def test_body_streaming_no_data() -> None: - body = Body(None, None) + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) + body = Body(body_chunks, None, None) semaphore = asyncio.Semaphore(0) - asyncio.ensure_future(_fill_body(body, semaphore, 0)) + asyncio.ensure_future(_fill_body(body_chunks, semaphore, 0)) async for _ in body: # noqa: F841 raise AssertionError("Should not reach this line") assert b"" == await body @@ -65,8 +84,9 @@ async def test_body_streaming_no_data() -> None: async def test_body_exceeds_max_content_length() -> None: max_content_length = 5 - body = Body(None, max_content_length) - body.append(b" " * (max_content_length + 1)) + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) + body = Body(body_chunks, None, max_content_length) + body_chunks.put_nowait(b" " * (max_content_length + 1)) with pytest.raises(RequestEntityTooLarge): await body @@ -85,6 +105,7 @@ async def test_request_exceeds_max_content_length(http_scope: HTTPScope) -> None "1.1", http_scope, max_content_length=max_content_length, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) with pytest.raises(RequestEntityTooLarge): @@ -92,6 +113,7 @@ async def test_request_exceeds_max_content_length(http_scope: HTTPScope) -> None async def test_request_get_data_timeout(http_scope: HTTPScope) -> None: + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) request = Request( "POST", "http", @@ -102,6 +124,7 @@ async def test_request_get_data_timeout(http_scope: HTTPScope) -> None: "1.1", http_scope, body_timeout=1, + body_chunks=body_chunks, send_push_promise=no_op_push, ) with pytest.raises(RequestTimeout): @@ -115,6 +138,7 @@ async def test_request_get_data_timeout(http_scope: HTTPScope) -> None: async def test_request_values( method: str, expected: list[str], http_scope: HTTPScope ) -> None: + body_chunks: AsyncQueueIterator[bytes] = AsyncQueueIterator(1) request = Request( method, "http", @@ -126,10 +150,11 @@ async def test_request_values( "", "1.1", http_scope, + body_chunks=body_chunks, send_push_promise=no_op_push, ) - request.body.append(urlencode({"a": "d"}).encode()) - request.body.set_complete() + body_chunks.put_nowait(urlencode({"a": "d"}).encode()) + body_chunks.set_complete() assert (await request.values).getlist("a") == expected @@ -157,6 +182,7 @@ async def _push(path: str, headers: Headers) -> None: "", "2", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=_push, ) await request.send_push_promise("/") diff --git a/tests/wrappers/test_response.py b/tests/wrappers/test_response.py index 15980ee..43605b6 100644 --- a/tests/wrappers/test_response.py +++ b/tests/wrappers/test_response.py @@ -14,6 +14,7 @@ from werkzeug.datastructures import Headers from werkzeug.exceptions import RequestedRangeNotSatisfiable +from quart.testing import make_test_body_chunks from quart.testing import no_op_push from quart.typing import HTTPScope from quart.wrappers import Request @@ -96,6 +97,7 @@ async def test_response_make_conditional(http_scope: HTTPScope) -> None: "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) response = Response(b"abcdef") @@ -119,6 +121,7 @@ async def test_response_make_conditional_no_condition(http_scope: HTTPScope) -> "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) response = Response(b"abcdef") @@ -137,6 +140,7 @@ async def test_response_make_conditional_out_of_bound(http_scope: HTTPScope) -> "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) response = Response(b"abcdef") @@ -157,6 +161,7 @@ async def test_response_make_conditional_not_modified(http_scope: HTTPScope) -> "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) await response.make_conditional(request) @@ -181,6 +186,7 @@ async def test_response_make_conditional_not_satisfiable( "", "1.1", http_scope, + body_chunks=make_test_body_chunks(), send_push_promise=no_op_push, ) response = Response(b"abcdef") From c43b6e4b5bec84c94db14a37c38e11112a729704 Mon Sep 17 00:00:00 2001 From: Michael Paul Date: Sat, 5 Apr 2025 12:15:12 -0400 Subject: [PATCH 4/6] Implement backpressure for WebSocket messages WebSocket already uses a queue for received messages, but like HTTP, the queue needs to be limited to a single item so that messages will be accepted from the client no faster than the application is able to use them. --- src/quart/asgi.py | 2 +- tests/test_asgi.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/src/quart/asgi.py b/src/quart/asgi.py index 9ac995b..1b3f27b 100644 --- a/src/quart/asgi.py +++ b/src/quart/asgi.py @@ -183,7 +183,7 @@ class ASGIWebsocketConnection: def __init__(self, app: Quart, scope: WebsocketScope) -> None: self.app = app self.scope = scope - self.queue: asyncio.Queue = asyncio.Queue() + self.queue: asyncio.Queue = asyncio.Queue(1) self._accepted = False self._closed = False diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 8bc7df2..394fc60 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -12,6 +12,7 @@ from werkzeug.datastructures import Headers from quart import Quart +from quart import websocket from quart.asgi import _convert_version from quart.asgi import _handle_exception from quart.asgi import ASGIHTTPConnection @@ -160,6 +161,61 @@ async def send(message: ASGISendEvent) -> None: await asyncio.wait_for(connection(receive, send), timeout=1) +async def test_websocket_backpressure() -> None: + app = Quart(__name__) + scope: WebsocketScope = { + "type": "websocket", + "asgi": {}, + "http_version": "1.1", + "scheme": "wss", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [(b"host", b"quart")], + "client": ("127.0.0.1", 80), + "server": None, + "subprotocols": [], + "extensions": {"websocket.http.response": {}}, + "state": {}, # type: ignore[typeddict-item] + } + connection = ASGIWebsocketConnection(app, scope) + + count = 3 + + queue: asyncio.Queue = asyncio.Queue() + queue.put_nowait({"type": "websocket.connect"}) + for i in range(count): + queue.put_nowait({"type": "websocket.receive", "text": str(i)}) + queue.put_nowait({"type": "websocket.disconnect"}) + + async def receive() -> ASGIReceiveEvent: + return await queue.get() + + async def send(message: ASGISendEvent) -> None: + pass + + size_checks: list[tuple[int, int]] = [] + + @app.websocket("/") + async def ws() -> None: + while True: + n = int(await websocket.receive()) + + size_check = (n, queue.qsize()) + size_checks.append(size_check) + + await connection(receive, send) + + assert len(size_checks) == count + for n, qsize in size_checks: + # At each step, the queue contains the remaining data messages except + # for the one that's just been received and the next one after it (that + # one's been moved to the connection's internal queue), plus the + # disconnect message. + assert qsize == (count - n - 2) + 1 + + def test_http_path_from_absolute_target() -> None: app = Quart(__name__) scope: HTTPScope = { From 6936587c4c4802654f5fd63f382fc0543e4b94c8 Mon Sep 17 00:00:00 2001 From: Michael Paul Date: Sat, 5 Apr 2025 14:31:26 -0400 Subject: [PATCH 5/6] typing.Self requires Python 3.11, not 3.10 --- src/quart/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quart/utils.py b/src/quart/utils.py index 7a2f1ce..00614c8 100644 --- a/src/quart/utils.py +++ b/src/quart/utils.py @@ -24,7 +24,7 @@ from .typing import Event from .typing import FilePath -if sys.version_info >= (3, 10): +if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self From 9bb2ffe4a82e63ed567aef1984841bb008050c50 Mon Sep 17 00:00:00 2001 From: Michael Paul Date: Sat, 5 Apr 2025 16:49:15 -0400 Subject: [PATCH 6/6] Bump typing-extensions dependency to include Python 3.10 This is needed for use of typing_extensions.Self on Python 3.10, since typing.Self doesn't exist before Python 3.11. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b15e778..ad4cb2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "itsdangerous", "jinja2", "markupsafe", - "typing-extensions; python_version < '3.10'", + "typing-extensions; python_version < '3.11'", "werkzeug>=3.0", ]