diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index ebf7e0f..22fb4f4 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -14,7 +14,7 @@ features: - STREAM_TYPE_CLIENT_STREAM - STREAM_TYPE_SERVER_STREAM - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM - - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM + # - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM supports_h2c: true supports_tls: true diff --git a/conformance/client_known_failing.yaml b/conformance/client_known_failing.yaml index 046b5a9..8b13789 100644 --- a/conformance/client_known_failing.yaml +++ b/conformance/client_known_failing.yaml @@ -1,2 +1 @@ -# Cancellation is not supported yet -Client Cancellation/** + diff --git a/conformance/client_runner.py b/conformance/client_runner.py index ea04bbb..b5203a8 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -6,7 +6,6 @@ import ssl import struct import sys -import time import traceback from collections.abc import AsyncGenerator from typing import Any @@ -142,6 +141,17 @@ def to_pb_headers(headers: Headers) -> list[service_pb2.Header]: ] +def to_connect_headers(pb_headers: RepeatedCompositeFieldContainer[service_pb2.Header]) -> Headers: + headers = Headers() + for h in pb_headers: + if key := headers.get(h.name.lower()): + headers[key] = f"{headers[key]}, {', '.join(h.value)}" + else: + headers[h.name.lower()] = ", ".join(h.value) + + return headers + + async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_compat_pb2.ClientCompatResponse: """Handle a client compatibility request and returns a response. @@ -204,9 +214,6 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c url = f"{proto}://{msg.host}:{msg.port}" - if msg.request_delay_ms > 0: - time.sleep(msg.request_delay_ms / 1000.0) - async with AsyncClientSession(http1=http1, http2=http2, ssl_context=ssl_context) as session: payloads = [] try: @@ -216,20 +223,28 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c client = service_connect.ConformanceServiceClient(base_url=url, session=session, options=options) if msg.stream_type == config_pb2.STREAM_TYPE_UNARY: + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000) + + abort_event = asyncio.Event() req = await anext(reqs) - header = Headers() - for h in msg.request_headers: - if key := header.get(h.name.lower()): - header[key] = f"{header[key]}, {', '.join(h.value)}" - else: - header[h.name.lower()] = ", ".join(h.value) + if msg.cancel.after_close_send_ms > 0: + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + + headers = to_connect_headers(msg.request_headers) resp = await getattr(client, msg.method)( UnaryRequest( message=req, - headers=header, + headers=headers, timeout=msg.timeout_ms / 1000, + abort_event=abort_event, ), ) payloads.append(resp.message.payload) @@ -243,29 +258,116 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c response_trailers=to_pb_headers(resp.trailers), ), ) + elif msg.stream_type == config_pb2.STREAM_TYPE_CLIENT_STREAM: + abort_event = asyncio.Event() + headers = to_connect_headers(msg.request_headers) + + async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: + async for req in reqs: + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000) + yield req + + if msg.cancel.HasField("before_close_send"): + abort_event.set() + + if msg.cancel.HasField("after_close_send_ms"): + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + + async with getattr(client, msg.method)( + StreamRequest( + messages=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event + ), + ) as resp: + async for message in resp.messages: + payloads.append(message.payload) + + return client_compat_pb2.ClientCompatResponse( + test_name=msg.test_name, + response=client_compat_pb2.ClientResponseResult( + payloads=payloads, + http_status_code=200, + response_headers=to_pb_headers(resp.headers), + response_trailers=to_pb_headers(resp.trailers), + ), + ) + elif msg.stream_type == config_pb2.STREAM_TYPE_SERVER_STREAM: + abort_event = asyncio.Event() + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000) + + headers = to_connect_headers(msg.request_headers) + + async with getattr(client, msg.method)( + StreamRequest( + messages=reqs, headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event + ), + ) as resp: + if msg.cancel.HasField("after_close_send_ms"): + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + + async for message in resp.messages: + payloads.append(message.payload) + if len(payloads) == msg.cancel.after_num_responses: + abort_event.set() + + return client_compat_pb2.ClientCompatResponse( + test_name=msg.test_name, + response=client_compat_pb2.ClientResponseResult( + payloads=payloads, + http_status_code=200, + response_headers=to_pb_headers(resp.headers), + response_trailers=to_pb_headers(resp.trailers), + ), + ) + elif ( - msg.stream_type == config_pb2.STREAM_TYPE_CLIENT_STREAM - or msg.stream_type == config_pb2.STREAM_TYPE_SERVER_STREAM - or msg.stream_type == config_pb2.STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM + msg.stream_type == config_pb2.STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM or msg.stream_type == config_pb2.STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM ): - header = Headers() - for h in msg.request_headers: - if key := header.get(h.name.lower()): - header[key] = f"{header[key]}, {', '.join(h.value)}" - else: - header[h.name.lower()] = ", ".join(h.value) + abort_event = asyncio.Event() - resp = await getattr(client, msg.method)( + async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: + async for req in reqs: + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000) + yield req + + headers = to_connect_headers(msg.request_headers) + + async with getattr(client, msg.method)( StreamRequest( - messages=reqs, - headers=header, - timeout=msg.timeout_ms / 1000, + messages=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), - ) + ) as resp: + if msg.cancel.HasField("before_close_send"): + abort_event.set() - async for message in resp.messages: - payloads.append(message.payload) + if msg.cancel.HasField("after_close_send_ms"): + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + + if msg.cancel.HasField("after_num_responses") and msg.cancel.after_num_responses == 0: + abort_event.set() + + async for message in resp.messages: + payloads.append(message.payload) + if len(payloads) == msg.cancel.after_num_responses: + abort_event.set() return client_compat_pb2.ClientCompatResponse( test_name=msg.test_name, @@ -306,11 +408,6 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c if "--debug" in sys.argv: logging.debug("Debug mode enabled") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - tasks = [] - async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None: """Run the message handler for a given request.""" try: @@ -325,16 +422,10 @@ async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None: async def read_requests() -> None: """Read requests from standard input and process them asynchronously.""" + loop = asyncio.get_event_loop() while req := await loop.run_in_executor(None, read_request): - task = loop.create_task(run_message(req)) - tasks.append(task) - - loop.run_until_complete(read_requests()) + await asyncio.sleep(0.01) - pending_tasks = [t for t in tasks if not t.done()] - if pending_tasks: - logger.info(f"Waiting for {len(pending_tasks)} pending tasks to complete...") - loop.run_until_complete(asyncio.gather(*pending_tasks)) + loop.create_task(run_message(req)) - logger.info("All done") - loop.close() + asyncio.run(read_requests()) diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt new file mode 100644 index 0000000..0fb9dec --- /dev/null +++ b/conformance/run-testcase.txt @@ -0,0 +1,2 @@ +Connect Unexpected Responses/HTTPVersion:2/TLS:true/client-stream/multiple-responses +Connect Unexpected Responses/HTTPVersion:2/TLS:false/client-stream/ok-but-no-response diff --git a/src/connect/client.py b/src/connect/client.py index c63261d..b3393d0 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -3,7 +3,8 @@ These classes allow making unary calls to a specified URL with given request and response types. """ -from collections.abc import Awaitable, Callable +import contextlib +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any import httpcore @@ -227,7 +228,7 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(request.message, request.timeout) + await conn.send(request.message, request.timeout, abort_event=request.abort_event) response = await recieve_unary_response(conn=conn, t=output) return response @@ -267,9 +268,9 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(request.messages, request.timeout) + await conn.send(request.messages, request.timeout, request.abort_event) - response = await recieve_stream_response(conn, output, request.spec) + response = await recieve_stream_response(conn, output, request.spec, request.abort_event) return response stream_func = apply_interceptors(_stream_func, options.interceptors) @@ -299,49 +300,84 @@ async def call_unary(self, request: UnaryRequest[T_Request]) -> UnaryResponse[T_ """ return await self._call_unary(request) - async def call_server_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - """Asynchronously calls a server streaming RPC (Remote Procedure Call) with the given request. + @contextlib.asynccontextmanager + async def call_server_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: + """Initiate a server-streaming RPC call and returns an asynchronous generator that yields responses from the server. Args: - request (UnaryRequest[T_Request]): The request object containing the data to be sent to the server. + request (StreamRequest[T_Request]): The request object containing the + data to be sent to the server. - Returns: - UnaryResponse[T_Response]: The response object containing the data received from the server. + Yields: + StreamResponse[T_Response]: The response objects received from the server. - """ - return await self._call_stream(StreamType.ServerStream, request) + Raises: + Any exceptions that occur during the streaming process. + + Notes: + - This method ensures that the response stream is properly closed + after the generator is exhausted or an exception occurs. + - The type parameters `T_Request` and `T_Response` represent the + request and response types, respectively. - async def call_client_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - """Asynchronously calls a client stream and yields responses. + """ + response = await self._call_stream(StreamType.ServerStream, request) + try: + yield response + finally: + await response.aclose() - This method sends a stream request to the client and asynchronously - iterates over the responses, yielding each response one by one. + @contextlib.asynccontextmanager + async def call_client_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: + """Initiate a client-streaming RPC call and returns an asynchronous generator for streaming responses from the server. Args: - request (StreamRequest[T_Request]): The stream request to be sent. + request (StreamRequest[T_Request]): The request object containing the + client-streaming data to be sent to the server. Yields: - UnaryResponse[T_Response]: The response from the client stream. + StreamResponse[T_Response]: An asynchronous generator that yields + responses from the server. + + Raises: + Any exceptions raised during the streaming call. + + Notes: + - The `response.aclose()` method is called in the `finally` block to + ensure proper cleanup of the response stream. """ - return await self._call_stream(StreamType.ClientStream, request) + response = await self._call_stream(StreamType.ClientStream, request) + try: + yield response + finally: + await response.aclose() - async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - """Initiate a bidirectional streaming call. + @contextlib.asynccontextmanager + async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: + """Initiate a bidirectional streaming call with the server. - This method establishes a bidirectional stream between the client and the server, - allowing both to send and receive messages asynchronously. + This method sends a stream request to the server and returns an asynchronous + generator that yields stream responses from the server. The connection is + automatically closed when the generator is exhausted or an exception occurs. Args: - request (StreamRequest[T_Request]): The request object containing the stream - of messages to be sent to the server. + request (StreamRequest[T_Request]): The stream request object containing + the data to be sent to the server. - Returns: - StreamResponse[T_Response]: An asynchronous stream response object that - allows receiving messages from the server. + Yields: + StreamResponse[T_Response]: The stream response object received from the server. Raises: - Any exceptions raised during the streaming call will propagate to the caller. + Any exceptions raised during the streaming call. + + Notes: + Ensure to consume the generator properly to avoid resource leaks, as the + connection is closed in the `finally` block. """ - return await self._call_stream(StreamType.BiDiStream, request) + response = await self._call_stream(StreamType.BiDiStream, request) + try: + yield response + finally: + await response.aclose() diff --git a/src/connect/connect.py b/src/connect/connect.py index 3d96a42..5584561 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -1,7 +1,8 @@ """Defines the streaming handler connection interfaces and related utilities.""" import abc -from collections.abc import AsyncIterator, Callable, Mapping +import asyncio +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping from enum import Enum from http import HTTPMethod from typing import Any, Protocol, cast @@ -12,7 +13,7 @@ from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.utils import aiterate, get_callable_attribute +from connect.utils import AsyncDataStream, aiterate, get_acallable_attribute, get_callable_attribute class StreamType(Enum): @@ -74,7 +75,6 @@ def __init__( """Initialize a new Request instance. Args: - messages (AsyncIterator[T]): An asynchronous iterator of messages. spec (Spec): The specification for the request. peer (Peer): The peer information. headers (Mapping[str, str]): The request headers. @@ -138,7 +138,7 @@ class StreamRequest[T](RequestCommon): """StreamRequest class represents a request that can handle streaming messages. Attributes: - messages (AsyncIterator[T]): An asynchronous iterator of messages. + messages (AsyncIterable[T]): An asynchronous iterable of messages. _spec (Spec): The specification for the request. _peer (Peer): The peer information. _headers (Headers): The request headers. @@ -146,38 +146,42 @@ class StreamRequest[T](RequestCommon): """ - _messages: AsyncIterator[T] + _messages: AsyncIterable[T] timeout: float | None + abort_event: asyncio.Event | None = None def __init__( self, - messages: AsyncIterator[T] | T, + messages: AsyncIterable[T] | T, spec: Spec | None = None, peer: Peer | None = None, headers: Headers | None = None, method: str | None = None, timeout: float | None = None, + abort_event: asyncio.Event | None = None, ) -> None: """Initialize a new Request instance. Args: - messages (AsyncIterator[T]): An asynchronous iterator of messages. + messages (AsyncIterable[T] | T): The request messages. spec (Spec): The specification for the request. peer (Peer): The peer information. headers (Mapping[str, str]): The request headers. method (str): The HTTP method used for the request. timeout (float): The timeout for the request. + abort_event (asyncio.Event): An event to signal request abortion. Returns: None """ super().__init__(spec, peer, headers, method) - self._messages = messages if isinstance(messages, AsyncIterator) else aiterate([messages]) + self._messages = messages if isinstance(messages, AsyncIterable) else aiterate([messages]) self.timeout = timeout + self.abort_event = abort_event @property - def messages(self) -> AsyncIterator[T]: + def messages(self) -> AsyncIterable[T]: """Return the request message.""" return self._messages @@ -196,6 +200,7 @@ class UnaryRequest[T](RequestCommon): _message: T timeout: float | None + abort_event: asyncio.Event | None = None def __init__( self, @@ -205,6 +210,7 @@ def __init__( headers: Headers | None = None, method: str | None = None, timeout: float | None = None, + abort_event: asyncio.Event | None = None, ) -> None: """Initialize a new Request instance. @@ -215,6 +221,7 @@ def __init__( headers (Mapping[str, str]): The request headers. method (str): The HTTP method used for the request. timeout (float): The timeout for the request. + abort_event (asyncio.Event): An event to signal request abortion. Returns: None @@ -223,6 +230,7 @@ def __init__( super().__init__(spec, peer, headers, method) self._message = message self.timeout = timeout + self.abort_event = abort_event @property def message(self) -> T: @@ -286,23 +294,29 @@ def message(self) -> T: class StreamResponse[T](ResponseCommon): """Response class for handling responses.""" - _messages: AsyncIterator[T] + _messages: AsyncIterable[T] def __init__( self, - messages: AsyncIterator[T] | T, + messages: AsyncIterable[T] | T, headers: Headers | None = None, trailers: Headers | None = None, ) -> None: """Initialize the response with a message.""" super().__init__(headers, trailers) - self._messages = messages if isinstance(messages, AsyncIterator) else aiterate([messages]) + self._messages = messages if isinstance(messages, AsyncIterable) else aiterate([messages]) @property - def messages(self) -> AsyncIterator[T]: + def messages(self) -> AsyncIterable[T]: """Return the response message.""" return self._messages + async def aclose(self) -> None: + """Asynchronously close the response stream.""" + aclose = get_acallable_attribute(self._messages, "aclose") + if aclose: + await aclose() + class UnaryHandlerConn(abc.ABC): """Abstract base class for a streaming handler connection. @@ -475,11 +489,11 @@ def request_headers(self) -> Headers: raise NotImplementedError() @abc.abstractmethod - async def send(self, messages: AsyncIterator[Any]) -> None: + async def send(self, messages: AsyncIterable[Any]) -> None: """Send a stream of messages asynchronously. Args: - messages (AsyncIterator[Any]): An asynchronous iterator that yields messages to be sent. + messages (AsyncIterable[Any]): The messages to be sent. Raises: NotImplementedError: This method should be implemented by subclasses. @@ -556,7 +570,7 @@ def request_headers(self) -> Headers: raise NotImplementedError() @abc.abstractmethod - async def send(self, message: Any, timeout: float | None) -> bytes: + async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: """Send a message.""" raise NotImplementedError() @@ -594,7 +608,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - def receive(self, message: Any) -> AsyncIterator[Any]: + def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: """Receives a message and processes it.""" raise NotImplementedError() @@ -605,7 +619,9 @@ def request_headers(self) -> Headers: raise NotImplementedError() @abc.abstractmethod - async def send(self, messages: AsyncIterator[Any], timeout: float | None) -> None: + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: """Send a stream of messages.""" raise NotImplementedError() @@ -626,6 +642,11 @@ def on_request_send(self, fn: Callable[..., Any]) -> None: """Handle the request send event.""" raise NotImplementedError() + @abc.abstractmethod + async def aclose(self) -> None: + """Asynchronously close the connection.""" + raise NotImplementedError() + class ReceiveConn(Protocol): """A protocol that defines the methods required for receiving connections.""" @@ -763,36 +784,78 @@ async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryR return UnaryResponse(message, conn.response_headers, conn.response_trailers) -async def recieve_stream_response[T](conn: StreamingClientConn, t: type[T], spec: Spec) -> StreamResponse[T]: - """Receive a stream response from a streaming client connection. +async def _receive_exactly_one[T](stream: AsyncIterator[T], aclose: Callable[[], Awaitable[None]]) -> T: + """Asynchronously receives exactly one item from an asynchronous iterator. + + This function ensures that the provided asynchronous iterator (`stream`) yields + exactly one item. If the iterator yields no items or more than one item, a + `ConnectError` is raised. The provided `aclose` callable is always invoked to + close the stream, regardless of success or failure. + + Type Parameters: + T: The type of the items in the asynchronous iterator. Args: - conn (StreamingClientConn): The streaming client connection. - t (type[T]): The type of the response to be received. - spec (Spec): The specification for the request. + stream (AsyncIterator[T]): The asynchronous iterator to consume. + aclose (Callable[[], Awaitable[None]]): A callable that closes the stream + when invoked. Returns: - StreamResponse[T]: The stream response containing the received data, response headers, and response trailers. - - """ - if spec.stream_type == StreamType.ClientStream: - count = 0 - single_message: T | None = None - async for message in conn.receive(t): - single_message = message - count += 1 + T: The single item yielded by the asynchronous iterator. - if single_message is None: - raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) + Raises: + ConnectError: If the iterator yields no items or more than one item. - if count > 1: + """ + try: + first = await stream.__anext__() + try: + await stream.__anext__() raise ConnectError( "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED ) + except StopAsyncIteration: + return first + except StopAsyncIteration: + raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) from None + finally: + await aclose() + + +async def recieve_stream_response[T]( + conn: StreamingClientConn, t: type[T], spec: Spec, abort_event: asyncio.Event | None +) -> StreamResponse[T]: + """Handle receiving a stream response from a streaming client connection. + + Args: + conn (StreamingClientConn): The streaming client connection used to receive the stream. + t (type[T]): The type of the messages expected in the stream. + spec (Spec): The specification of the stream, including its type. + abort_event (asyncio.Event | None): An optional event to signal abortion of the stream. + + Returns: + StreamResponse[T]: A response object containing the received stream, response headers, + and response trailers. - return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) + Raises: + Any exceptions raised during the reception of the stream or processing of the messages. + + Notes: + - If the stream type is `StreamType.ClientStream`, it expects exactly one message + and wraps it in a single-message stream. + - For other stream types, it directly returns the received stream. + + """ + receive_stream = AsyncDataStream[T](conn.receive(t, abort_event), conn.aclose) + + if spec.stream_type == StreamType.ClientStream: + single_message = await _receive_exactly_one(receive_stream.__aiter__(), receive_stream.aclose) + + return StreamResponse( + AsyncDataStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers + ) else: - return StreamResponse(conn.receive(t), conn.response_headers, conn.response_trailers) + return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 085beae..ee3e2d2 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1,12 +1,12 @@ """Provides classes and functions for handling protocol connections.""" +import asyncio import base64 import contextlib import json from collections.abc import ( AsyncIterable, AsyncIterator, - Awaitable, Callable, Mapping, ) @@ -55,7 +55,12 @@ ) from connect.request import Request from connect.session import AsyncClientSession -from connect.utils import AsyncByteStream, AsyncIteratorByteStream, aiterate, map_httpcore_exceptions +from connect.utils import ( + AsyncByteStream, + aiterate, + get_acallable_attribute, + map_httpcore_exceptions, +) from connect.version import __version__ from connect.writer import ServerResponseWriter @@ -423,7 +428,7 @@ class ConnectUnaryUnmarshaler: codec: Codec | None read_max_bytes: int compression: Compression | None - stream: AsyncIteratorByteStream | None + stream: AsyncIterable[bytes] | None def __init__( self, @@ -444,7 +449,7 @@ def __init__( self.codec = codec self.read_max_bytes = read_max_bytes self.compression = compression - self.stream = AsyncIteratorByteStream(stream) if stream else None + self.stream = stream async def unmarshal(self, message: Any) -> Any: """Asynchronously unmarshals a given message using the provided unmarshal function and codec. @@ -512,10 +517,21 @@ async def unmarshal_func(self, message: Any, func: Callable[[bytes, Any], Any]) Code.INVALID_ARGUMENT, ) from e finally: - await self.stream.aclose() + await self.aclose() return obj + async def aclose(self) -> None: + """Asynchronously close the stream if it is set. + + This method is intended to be called when the stream is no longer needed + to release any associated resources. + + """ + aclose = get_acallable_attribute(self.stream, "aclose") + if aclose: + await aclose() + class ConnectUnaryMarshaler: """ConnectUnaryMarshaler is responsible for serializing and optionally compressing messages. @@ -1099,42 +1115,48 @@ def _write_with_get(self, url: URL) -> None: self.url = url -class ResponseAsyncByteStream(AsyncByteStream): +class HTTPCoreResponseAsyncByteStream(AsyncByteStream): """An asynchronous byte stream for reading and writing byte chunks.""" aiterator: AsyncIterable[bytes] | None - aclose_func: Callable[..., Awaitable[None]] | None + _closed: bool def __init__( self, aiterator: AsyncIterable[bytes] | None = None, - aclose_func: Callable[..., Awaitable[None]] | None = None, ) -> None: """Initialize the protocol connect instance. Args: aiterator (AsyncIterable[bytes] | None): An optional asynchronous iterable of bytes. - aclose_func (Callable[..., Awaitable[None]] | None): An optional asynchronous close function. Returns: None """ self.aiterator = aiterator - self.aclose_func = aclose_func + self._closed = False async def __aiter__(self) -> AsyncIterator[bytes]: """Asynchronous iterator method to read byte chunks from the stream.""" - if self.aiterator is not None: - with map_httpcore_exceptions(): - async for chunk in self.aiterator: - yield chunk + if self.aiterator: + try: + with map_httpcore_exceptions(): + async for chunk in self.aiterator: + yield chunk + except BaseException as exc: + await self.aclose() + raise exc async def aclose(self) -> None: """Asynchronously close the stream.""" - if self.aclose_func: + if not self._closed and self.aiterator: + aclose = get_acallable_attribute(self.aiterator, "aclose") + if not aclose: + return + with map_httpcore_exceptions(): - await self.aclose_func() + await aclose() class ConnectStreamingMarshaler: @@ -1168,11 +1190,11 @@ def __init__( self.send_max_bytes = send_max_bytes self.compression = compression - async def marshal(self, messages: AsyncIterator[Any]) -> AsyncIterator[bytes]: + async def marshal(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: """Asynchronously marshals and compresses messages from an asynchronous iterator. Args: - messages (AsyncIterator[Any]): An asynchronous iterator of messages to be marshaled. + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled. Yields: AsyncIterator[bytes]: An asynchronous iterator of marshaled and optionally compressed messages in bytes. @@ -1268,7 +1290,7 @@ class ConnectStreamingUnmarshaler: Attributes: codec (Codec): The codec used for unmarshaling data. compression (Compression | None): The compression method used, if any. - stream (AsyncIteratorByteStream | None): The asynchronous byte stream to read data from. + stream (AsyncIterable[bytes] | None): The asynchronous byte stream to read from. buffer (bytes): The buffer to store incoming data chunks. """ @@ -1276,7 +1298,7 @@ class ConnectStreamingUnmarshaler: codec: Codec | None read_max_bytes: int compression: Compression | None - stream: AsyncIteratorByteStream | None + stream: AsyncIterable[bytes] | None buffer: bytes _end_stream_error: ConnectError | None _trailers: Headers @@ -1300,7 +1322,7 @@ def __init__( self.codec = codec self.read_max_bytes = read_max_bytes self.compression = compression - self.stream = AsyncIteratorByteStream(stream) if stream else None + self.stream = stream self.buffer = b"" self._end_stream_error = None self._trailers = Headers() @@ -1325,58 +1347,56 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: if self.codec is None: raise ConnectError("codec is not set", Code.INTERNAL) - try: - async for chunk in self.stream: - self.buffer += chunk + async for chunk in self.stream: + self.buffer += chunk + + while True: + env, data_len = Envelope.decode(self.buffer) + if env is None: + break + + if self.read_max_bytes > 0 and data_len > self.read_max_bytes: + raise ConnectError( + f"message size {data_len} is larger than configured readMaxBytes {self.read_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) - while True: - env, data_len = Envelope.decode(self.buffer) - if env is None: - break + self.buffer = self.buffer[5 + data_len :] - if self.read_max_bytes > 0 and data_len > self.read_max_bytes: + if env.is_set(EnvelopeFlags.compressed): + if not self.compression: raise ConnectError( - f"message size {data_len} is larger than configured readMaxBytes {self.read_max_bytes}", - Code.RESOURCE_EXHAUSTED, + "protocol error: sent compressed message without compression support", Code.INTERNAL ) - self.buffer = self.buffer[5 + data_len :] - - if env.is_set(EnvelopeFlags.compressed): - if not self.compression: - raise ConnectError( - "protocol error: sent compressed message without compression support", Code.INTERNAL - ) - - env.data = self.compression.decompress(env.data, self.read_max_bytes) - - if env.is_set(EnvelopeFlags.end_stream): - error, trailers = end_stream_from_bytes(env.data) - self._end_stream_error = error - self._trailers = trailers - end = True - obj = None - else: - try: - obj = self.codec.unmarshal(env.data, message) - except Exception as e: - raise ConnectError( - f"unmarshal message: {str(e)}", - Code.INVALID_ARGUMENT, - ) from e - - end = False - - yield obj, end - - if len(self.buffer) > 0: - header = Envelope.decode_header(self.buffer) - if header: - message = f"protocol error: promised {header[1]} bytes in enveloped message, got {len(self.buffer) - 5} bytes" - raise ConnectError(message, Code.INVALID_ARGUMENT) + env.data = self.compression.decompress(env.data, self.read_max_bytes) + + if env.is_set(EnvelopeFlags.end_stream): + error, trailers = end_stream_from_bytes(env.data) + self._end_stream_error = error + self._trailers = trailers + end = True + obj = None + else: + try: + obj = self.codec.unmarshal(env.data, message) + except Exception as e: + raise ConnectError( + f"unmarshal message: {str(e)}", + Code.INVALID_ARGUMENT, + ) from e - finally: - await self.stream.aclose() + end = False + + yield obj, end + + if len(self.buffer) > 0: + header = Envelope.decode_header(self.buffer) + if header: + message = ( + f"protocol error: promised {header[1]} bytes in enveloped message, got {len(self.buffer) - 5} bytes" + ) + raise ConnectError(message, Code.INVALID_ARGUMENT) @property def trailers(self) -> Headers: @@ -1401,6 +1421,20 @@ def end_stream_error(self) -> ConnectError | None: """ return self._end_stream_error + async def aclose(self) -> None: + """Asynchronously closes the stream if it has an `aclose` method. + + This method checks if the `self.stream` object has an asynchronous + `aclose` method. If the method exists, it is invoked to close the stream. + + Returns: + None + + """ + aclose = get_acallable_attribute(self.stream, "aclose") + if aclose: + await aclose() + class ConnectStreamingHandlerConn(StreamingHandlerConn): """ConnectStreamingHandlerConn is a class that handles streaming connections for the Connect protocol. @@ -1506,7 +1540,7 @@ def request_headers(self) -> Headers: """ return self._request_headers - async def send(self, messages: AsyncIterator[Any]) -> None: + async def send(self, messages: AsyncIterable[Any]) -> None: """Send a stream of messages asynchronously. This method marshals the provided messages and sends them using the writer. @@ -1514,7 +1548,7 @@ async def send(self, messages: AsyncIterator[Any]) -> None: converts it to a JSON object, and sends it as the final message in the stream. Args: - messages (AsyncIterator[Any]): An asynchronous iterator of messages to be sent. + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. Returns: None @@ -1719,18 +1753,26 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def receive(self, message: Any) -> AsyncIterator[Any]: + async def receive(self, message: Any, abort_event: asyncio.Event | None = None) -> AsyncIterator[Any]: """Asynchronously receives and processes a message. Args: message (Any): The message to be processed. + abort_event (asyncio.Event | None): Event to signal abortion of the operation. Yields: Any: Objects obtained from unmarshaling the message. + Raises: + ConnectError: If stream is malformed or aborted. + """ end_stream_received = False + async for obj, end in self.unmarshaler.unmarshal(message): + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + if end: if end_stream_received: raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) @@ -1756,24 +1798,33 @@ async def receive(self, message: Any) -> AsyncIterator[Any]: if not end_stream_received: raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) - async def send(self, messages: AsyncIterator[Any], timeout: float | None) -> None: - """Send a series of messages asynchronously. - - This method marshals the provided messages, constructs an HTTP POST request, - and sends it using the httpcore library. It also triggers any registered - request and response hooks, and validates the response. + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + """Send an asynchronous HTTP POST request with the given messages and handle the response. Args: - messages (AsyncIterator[Any]): An asynchronous iterator of messages to be sent. - timeout (float | None): The timeout for the request in seconds. - - Returns: - None + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. + timeout (float | None): Optional timeout value in seconds for the request. If provided, + it sets the read timeout for the request. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request. Raises: - Exception: If there is an error during the request or response handling. + ConnectError: If the request is aborted or if there is an error during the request. + + Hooks: + - Executes hooks registered in `self._event_hooks["request"]` before sending the request. + - Executes hooks registered in `self._event_hooks["response"]` after receiving the response. + + Notes: + - If `abort_event` is provided and set during the request, the request will be canceled, + and a `ConnectError` with code `Code.CANCELED` will be raised. + - The response stream is unmarshaled and validated after the request is completed. """ + if abort_event and abort_event.is_set(): + raise ConnectError("request aborted", Code.CANCELED) + extensions = {} if timeout: extensions["timeout"] = {"read": timeout} @@ -1802,14 +1853,32 @@ async def send(self, messages: AsyncIterator[Any], timeout: float | None) -> Non hook(request) with map_httpcore_exceptions(): - response = await self.session.pool.handle_async_request(request) + if not abort_event: + response = await self.session.pool.handle_async_request(request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task + + raise ConnectError("request aborted", Code.CANCELED) + + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task for hook in self._event_hooks["response"]: hook(response) - self.unmarshaler.stream = AsyncIteratorByteStream( - ResponseAsyncByteStream(aiterator=response.aiter_stream(), aclose_func=response.aclose) - ) + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) await self._validate_response(response) @@ -1852,6 +1921,15 @@ async def _validate_response(self, response: httpcore.Response) -> None: self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) self._response_headers.update(response_headers) + async def aclose(self) -> None: + """Asynchronously closes the connection by invoking the `aclose` method of the unmarshaler. + + Returns: + None + + """ + await self.unmarshaler.aclose() + class ConnectUnaryClientConn(UnaryClientConn): """A client connection for unary RPCs using the Connect protocol. @@ -1983,20 +2061,27 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def send(self, message: Any, timeout: float | None) -> bytes: - """Send a message asynchronously and returns the marshaled data. + async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: + """Send a message asynchronously using the specified HTTP method and handles the response. Args: - message (Any): The message to be sent. - timeout (float | None): The timeout for the request in seconds. + message (Any): The message to be sent, which will be marshaled before sending. + timeout (float | None): The timeout for the request in seconds. If provided, it will be + included in the request headers and extensions. + abort_event (asyncio.Event | None): An optional asyncio event that can be used to abort + the request. If the event is set, the request will be canceled. Returns: - bytes: The marshaled data of the message. + bytes: The marshaled data of the message that was sent. Raises: - Exception: If the response validation fails. + ConnectError: If the request is aborted or if there are issues during the request/response + lifecycle. """ + if abort_event and abort_event.is_set(): + raise ConnectError("request aborted", Code.CANCELED) + extensions = {} if timeout: extensions["timeout"] = {"read": timeout} @@ -2046,14 +2131,32 @@ async def send(self, message: Any, timeout: float | None) -> bytes: hook(request) with map_httpcore_exceptions(): - response = await self.session.pool.handle_async_request(request=request) + if not abort_event: + response = await self.session.pool.handle_async_request(request=request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task + + raise ConnectError("request aborted", Code.CANCELED) + + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task for hook in self._event_hooks["response"]: hook(response) - self.unmarshaler.stream = AsyncIteratorByteStream( - ResponseAsyncByteStream(response.aiter_stream(), response.aclose) - ) + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(response.stream) await self._validate_response(response) diff --git a/src/connect/utils.py b/src/connect/utils.py index 9ef0d80..1bab7b7 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -87,10 +87,34 @@ def get_callable_attribute(obj: object, attr: str) -> typing.Callable[..., typin typing.Callable[..., typing.Any] | None: The callable attribute if it exists and is callable, otherwise None. """ - if hasattr(obj, attr) and callable(getattr(obj, attr)): - return getattr(obj, attr) + try: + attr_value = getattr(obj, attr) + if callable(attr_value): + return attr_value + return None + except AttributeError: + return None + + +def get_acallable_attribute(obj: object, attr: str) -> typing.Callable[..., typing.Awaitable[typing.Any]] | None: + """Retrieve an attribute from an object if it is both callable and asynchronous. + + Args: + obj (object): The object from which to retrieve the attribute. + attr (str): The name of the attribute to retrieve. - return None + Returns: + typing.Callable[..., typing.Awaitable[typing.Any]] | None: + The attribute if it is callable and asynchronous, otherwise None. + + """ + try: + attr_value = getattr(obj, attr) + if callable(attr_value) and is_async_callable(attr_value): + return attr_value + return None + except AttributeError: + return None def get_route_path(scope: Scope) -> str: @@ -192,60 +216,88 @@ def __init__(self) -> None: super().__init__("Stream has already been consumed.") -class AsyncIteratorByteStream: - """An asynchronous iterator for byte streams. +class AsyncDataStream[T]: + """AsyncDataStream is a generic class that provides an asynchronous iterable interface for streaming data. - This class wraps an asynchronous iterable of bytes and provides an - asynchronous iterator interface. It ensures that the stream is only - consumed once and provides a method to close the stream if it supports - asynchronous closing. + It ensures that the stream is consumed only once and provides a mechanism for resource cleanup. + Type Parameters: + T: The type of elements in the asynchronous stream. Attributes: - _stream (typing.AsyncIterable[bytes]): The asynchronous iterable byte stream. - _is_stream_consumed (bool): A flag indicating whether the stream has been consumed. + _stream (typing.AsyncIterable[T]): The asynchronous iterable representing the stream of data. + _is_stream_consumed (bool): A flag indicating whether the stream has already been consumed. + aclose_func (Callable[..., Awaitable[None]] | None): An optional asynchronous callable for closing resources. + + Methods: + __init__(stream: typing.AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None) -> None: + Initializes the AsyncDataStream instance with the given stream and optional close function. + __aiter__() -> typing.AsyncIterator[T]: + Asynchronously iterates over the elements of the stream. Ensures the stream is consumed only once + and handles cleanup in case of exceptions. + aclose() -> None: + Asynchronously closes resources if an asynchronous close function is provided. """ - def __init__(self, stream: typing.AsyncIterable[bytes]) -> None: - """Initialize the utility with an asynchronous byte stream. + _stream: typing.AsyncIterable[T] + _is_stream_consumed: bool + aclose_func: Callable[..., Awaitable[None]] | None + + def __init__( + self, stream: typing.AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None + ) -> None: + """Initialize an instance of the class. Args: - stream (typing.AsyncIterable[bytes]): An asynchronous iterable that yields bytes. + stream (typing.AsyncIterable[T]): An asynchronous iterable representing the stream of data. + aclose_func (Callable[..., Awaitable[None]] | None, optional): + A callable function that is awaited to close the stream. Defaults to None. """ self._stream = stream self._is_stream_consumed = False + self.aclose_func = aclose_func - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - """Asynchronously iterates over the stream and yields parts of it. + async def __aiter__(self) -> typing.AsyncIterator[T]: + """Asynchronously iterates over the elements of the stream. + + This method allows the object to be used as an asynchronous iterator. + It ensures that the stream is not consumed multiple times and properly + handles cleanup in case of exceptions. Yields: - bytes: Parts of the stream. + T: The next element in the asynchronous stream. Raises: StreamConsumedError: If the stream has already been consumed. + BaseException: Propagates any exception raised during iteration + after ensuring the stream is closed. """ if self._is_stream_consumed: raise StreamConsumedError() self._is_stream_consumed = True - async for part in self._stream: - yield part + try: + async for part in self._stream: + yield part + except BaseException as exc: + await self.aclose() + raise exc async def aclose(self) -> None: - """Asynchronously closes the stream if it has an `aclose` method. + """Asynchronously closes resources if an asynchronous close function is provided. - This method checks if the `_stream` attribute has an `aclose` method and - calls it asynchronously to close the stream. If the `_stream` does not - have an `aclose` method, this method does nothing. + This method checks if an `aclose_func` is defined. If it is, the function + is awaited to perform any necessary cleanup or resource deallocation. Returns: None """ - if isinstance(self._stream, AsyncByteStream): - await self._stream.aclose() + if self.aclose_func: + await self.aclose_func() + return async def aiterate[T](iterable: typing.Iterable[T]) -> typing.AsyncIterator[T]: diff --git a/tests/test_streaming_connect_client.py b/tests/test_streaming_connect_client.py index 0f4b3d3..365c9dd 100644 --- a/tests/test_streaming_connect_client.py +++ b/tests/test_streaming_connect_client.py @@ -69,11 +69,11 @@ async def test_server_streaming(hypercorn_server: ServerConfig) -> None: client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) async def server_streaming_end_stream_error(scope: Scope, receive: Receive, send: Send) -> None: @@ -120,18 +120,18 @@ async def test_server_streaming_end_stream_error(hypercorn_server: ServerConfig) client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.UNAVAILABLE - assert excinfo.value.metadata["acme-operation-cost"] == "237" - assert excinfo.value.raw_message == "" - assert len(excinfo.value.details) == 0 - assert excinfo.value.wire_error is True + assert excinfo.value.code == Code.UNAVAILABLE + assert excinfo.value.metadata["acme-operation-cost"] == "237" + assert excinfo.value.raw_message == "" + assert len(excinfo.value.details) == 0 + assert excinfo.value.wire_error is True async def server_streaming_received_message_after_end_stream(scope: Scope, receive: Receive, send: Send) -> None: @@ -183,16 +183,16 @@ async def test_server_streaming_received_message_after_end_stream(hypercorn_serv client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.INVALID_ARGUMENT - assert excinfo.value.raw_message == "received message after end stream" + assert excinfo.value.code == Code.INVALID_ARGUMENT + assert excinfo.value.raw_message == "received message after end stream" async def server_streaming_received_extra_end_stream(scope: Scope, receive: Receive, send: Send) -> None: @@ -248,16 +248,16 @@ async def test_server_streaming_received_extra_end_stream(hypercorn_server: Serv client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.INVALID_ARGUMENT - assert excinfo.value.raw_message == "received extra end stream message" + assert excinfo.value.code == Code.INVALID_ARGUMENT + assert excinfo.value.raw_message == "received extra end stream message" async def server_streaming_not_received_end_stream(scope: Scope, receive: Receive, send: Send) -> None: @@ -299,16 +299,16 @@ async def test_server_streaming_not_received_end_stream(hypercorn_server: Server client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.INVALID_ARGUMENT - assert excinfo.value.raw_message == "missing end stream message" + assert excinfo.value.code == Code.INVALID_ARGUMENT + assert excinfo.value.raw_message == "missing end stream message" async def server_streaming_response_envelope_message_compression(scope: Scope, receive: Receive, send: Send) -> None: @@ -356,11 +356,11 @@ async def test_server_streaming_response_envelope_message_compression(hypercorn_ client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) async def server_streaming_request_envelope_message_compression(scope: Scope, receive: Receive, send: Send) -> None: @@ -421,11 +421,11 @@ async def test_server_streaming_request_envelope_message_compression(hypercorn_s ) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) @pytest.mark.asyncio() @@ -481,14 +481,13 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ping_request = StreamRequest(messages=PingRequest(name="test")) - await client.call_server_stream(ping_request) + async with client.call_server_stream(ping_request): + assert len(ephemeral_files) == 2 + for i, ephemeral_file in enumerate(reversed(ephemeral_files)): + ephemeral_file.seek(0) + assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() - assert len(ephemeral_files) == 2 - for i, ephemeral_file in enumerate(reversed(ephemeral_files)): - ephemeral_file.seek(0) - assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() - - ephemeral_file.close() + ephemeral_file.close() async def server_streaming_not_httpstatus_200(scope: Scope, receive: Receive, send: Send) -> None: @@ -520,12 +519,11 @@ async def test_server_streaming_not_httpstatus_200(hypercorn_server: ServerConfi ping_request = StreamRequest(messages=PingRequest(name="Bob")) with pytest.raises(ConnectError) as excinfo: - await client.call_server_stream(ping_request) - - assert excinfo.value.code == Code.UNAVAILABLE - assert len(excinfo.value.details) == 0 - assert excinfo.value.wire_error is False - assert excinfo.value.metadata == {} + async with client.call_server_stream(ping_request): + assert excinfo.value.code == Code.UNAVAILABLE + assert len(excinfo.value.details) == 0 + assert excinfo.value.wire_error is False + assert excinfo.value.metadata == {} async def client_streaming(scope: Scope, receive: Receive, send: Send) -> None: @@ -582,11 +580,11 @@ async def iterator() -> AsyncIterator[PingRequest]: client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=iterator()) - response = await client.call_client_stream(ping_request) - want = ["I'm fine."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_client_stream(ping_request) as response: + want = ["I'm fine."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) @pytest.mark.asyncio() @@ -645,11 +643,10 @@ async def iterator() -> AsyncIterator[PingRequest]: ping_request = StreamRequest(messages=iterator()) - await client.call_client_stream(ping_request) - - assert len(ephemeral_files) == 2 - for i, ephemeral_file in enumerate(reversed(ephemeral_files)): - ephemeral_file.seek(0) - assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() + async with client.call_client_stream(ping_request): + assert len(ephemeral_files) == 2 + for i, ephemeral_file in enumerate(reversed(ephemeral_files)): + ephemeral_file.seek(0) + assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() - ephemeral_file.close() + ephemeral_file.close()