From 77e63cde1aa118ef3d0289de070336d27dd76af0 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Thu, 10 Apr 2025 01:09:42 +0900 Subject: [PATCH 01/16] conformance: fix client cancellation --- conformance/client_runner.py | 130 +++++++++++++++++++++++++++---- conformance/run-testcase-tmp.txt | 6 ++ conformance/run-testcase.txt | 1 + src/connect/client.py | 6 +- src/connect/connect.py | 23 ++++-- src/connect/protocol_connect.py | 61 +++++++++++++-- 6 files changed, 197 insertions(+), 30 deletions(-) create mode 100644 conformance/run-testcase-tmp.txt create mode 100644 conformance/run-testcase.txt diff --git a/conformance/client_runner.py b/conformance/client_runner.py index ea04bbb..0e85c34 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -6,7 +6,6 @@ import ssl import struct import sys -import time import traceback from collections.abc import AsyncGenerator from typing import Any @@ -204,9 +203,6 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c url = f"{proto}://{msg.host}:{msg.port}" - if msg.request_delay_ms > 0: - time.sleep(msg.request_delay_ms / 1000.0) - async with AsyncClientSession(http1=http1, http2=http2, ssl_context=ssl_context) as session: payloads = [] try: @@ -216,8 +212,20 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c client = service_connect.ConformanceServiceClient(base_url=url, session=session, options=options) if msg.stream_type == config_pb2.STREAM_TYPE_UNARY: + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000.0) + + abort_event = asyncio.Event() req = await anext(reqs) + if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + header = Headers() for h in msg.request_headers: if key := header.get(h.name.lower()): @@ -230,6 +238,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c message=req, headers=header, timeout=msg.timeout_ms / 1000, + abort_event=abort_event, ), ) payloads.append(resp.message.payload) @@ -243,12 +252,97 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c response_trailers=to_pb_headers(resp.trailers), ), ) + elif msg.stream_type == config_pb2.STREAM_TYPE_CLIENT_STREAM: + abort_event = asyncio.Event() + header = Headers() + for h in msg.request_headers: + if key := header.get(h.name.lower()): + header[key] = f"{header[key]}, {', '.join(h.value)}" + else: + header[h.name.lower()] = ", ".join(h.value) + + async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: + async for req in reqs: + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000.0) + yield req + + if msg.HasField("cancel") and msg.cancel.HasField("before_close_send"): + abort_event.set() + elif msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + + resp = await getattr(client, msg.method)( + StreamRequest( + messages=_reqs(), headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event + ), + ) + + async for message in resp.messages: + payloads.append(message.payload) + + return client_compat_pb2.ClientCompatResponse( + test_name=msg.test_name, + response=client_compat_pb2.ClientResponseResult( + payloads=payloads, + http_status_code=200, + response_headers=to_pb_headers(resp.headers), + response_trailers=to_pb_headers(resp.trailers), + ), + ) + elif msg.stream_type == config_pb2.STREAM_TYPE_SERVER_STREAM: + abort_event = asyncio.Event() + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000.0) + + header = Headers() + for h in msg.request_headers: + if key := header.get(h.name.lower()): + header[key] = f"{header[key]}, {', '.join(h.value)}" + else: + header[h.name.lower()] = ", ".join(h.value) + + resp = await getattr(client, msg.method)( + StreamRequest( + messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event + ), + ) + + if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + + async for message in resp.messages: + payloads.append(message.payload) + if len(payloads) == msg.cancel.after_num_responses: + abort_event.set() + + return client_compat_pb2.ClientCompatResponse( + test_name=msg.test_name, + response=client_compat_pb2.ClientResponseResult( + payloads=payloads, + http_status_code=200, + response_headers=to_pb_headers(resp.headers), + response_trailers=to_pb_headers(resp.trailers), + ), + ) + elif ( - msg.stream_type == config_pb2.STREAM_TYPE_CLIENT_STREAM - or msg.stream_type == config_pb2.STREAM_TYPE_SERVER_STREAM - or msg.stream_type == config_pb2.STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM + msg.stream_type == config_pb2.STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM or msg.stream_type == config_pb2.STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM ): + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000.0) + header = Headers() for h in msg.request_headers: if key := header.get(h.name.lower()): @@ -305,8 +399,12 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c if __name__ == "__main__": if "--debug" in sys.argv: logging.debug("Debug mode enabled") + import tracemalloc + + tracemalloc.start() loop = asyncio.new_event_loop() + loop.set_debug(True) asyncio.set_event_loop(loop) tasks = [] @@ -321,6 +419,8 @@ async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None: error=client_compat_pb2.ClientErrorResult(message="".join(traceback.format_exception(e))), ) + await asyncio.sleep(3) + write_response(resp) async def read_requests() -> None: @@ -329,12 +429,10 @@ async def read_requests() -> None: task = loop.create_task(run_message(req)) tasks.append(task) - loop.run_until_complete(read_requests()) - - pending_tasks = [t for t in tasks if not t.done()] - if pending_tasks: - logger.info(f"Waiting for {len(pending_tasks)} pending tasks to complete...") - loop.run_until_complete(asyncio.gather(*pending_tasks)) - - logger.info("All done") - loop.close() + try: + loop.run_until_complete(read_requests()) + finally: + pending = asyncio.all_tasks(loop) + if pending: + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.close() diff --git a/conformance/run-testcase-tmp.txt b/conformance/run-testcase-tmp.txt new file mode 100644 index 0000000..bcd9ca2 --- /dev/null +++ b/conformance/run-testcase-tmp.txt @@ -0,0 +1,6 @@ +Client Cancellation/**/unary/cancel-after-close-send +Client Cancellation/**/client-stream/cancel-before-close-send +Client Cancellation/**/client-stream/cancel-after-close-send +Client Cancellation/**/server-stream/cancel-after-close-send +Client Cancellation/**/server-stream/cancel-after-responses +Client Cancellation/HTTPVersion:2/Protocol:PROTOCOL_CONNECT/Codec:CODEC_PROTO/Compression:COMPRESSION_GZIP/TLS:false/bidi-stream/full-duplex/cancel-after-responses diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt new file mode 100644 index 0000000..87b39f7 --- /dev/null +++ b/conformance/run-testcase.txt @@ -0,0 +1 @@ +Client Cancellation/HTTPVersion:2/Protocol:PROTOCOL_CONNECT/Codec:CODEC_PROTO/Compression:COMPRESSION_GZIP/TLS:false/bidi-stream/full-duplex/cancel-after-responses diff --git a/src/connect/client.py b/src/connect/client.py index c63261d..720f5c9 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -227,7 +227,7 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(request.message, request.timeout) + await conn.send(request.message, request.timeout, abort_event=request.abort_event) response = await recieve_unary_response(conn=conn, t=output) return response @@ -267,9 +267,9 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(request.messages, request.timeout) + await conn.send(request.messages, request.timeout, request.abort_event) - response = await recieve_stream_response(conn, output, request.spec) + response = await recieve_stream_response(conn, output, request.spec, request.abort_event) return response stream_func = apply_interceptors(_stream_func, options.interceptors) diff --git a/src/connect/connect.py b/src/connect/connect.py index 3d96a42..191e25d 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -1,6 +1,7 @@ """Defines the streaming handler connection interfaces and related utilities.""" import abc +import asyncio from collections.abc import AsyncIterator, Callable, Mapping from enum import Enum from http import HTTPMethod @@ -148,6 +149,7 @@ class StreamRequest[T](RequestCommon): _messages: AsyncIterator[T] timeout: float | None + abort_event: asyncio.Event | None = None def __init__( self, @@ -157,6 +159,7 @@ def __init__( headers: Headers | None = None, method: str | None = None, timeout: float | None = None, + abort_event: asyncio.Event | None = None, ) -> None: """Initialize a new Request instance. @@ -175,6 +178,7 @@ def __init__( super().__init__(spec, peer, headers, method) self._messages = messages if isinstance(messages, AsyncIterator) else aiterate([messages]) self.timeout = timeout + self.abort_event = abort_event @property def messages(self) -> AsyncIterator[T]: @@ -196,6 +200,7 @@ class UnaryRequest[T](RequestCommon): _message: T timeout: float | None + abort_event: asyncio.Event | None = None def __init__( self, @@ -205,6 +210,7 @@ def __init__( headers: Headers | None = None, method: str | None = None, timeout: float | None = None, + abort_event: asyncio.Event | None = None, ) -> None: """Initialize a new Request instance. @@ -223,6 +229,7 @@ def __init__( super().__init__(spec, peer, headers, method) self._message = message self.timeout = timeout + self.abort_event = abort_event @property def message(self) -> T: @@ -556,7 +563,7 @@ def request_headers(self) -> Headers: raise NotImplementedError() @abc.abstractmethod - async def send(self, message: Any, timeout: float | None) -> bytes: + async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: """Send a message.""" raise NotImplementedError() @@ -594,7 +601,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - def receive(self, message: Any) -> AsyncIterator[Any]: + def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: """Receives a message and processes it.""" raise NotImplementedError() @@ -605,7 +612,9 @@ def request_headers(self) -> Headers: raise NotImplementedError() @abc.abstractmethod - async def send(self, messages: AsyncIterator[Any], timeout: float | None) -> None: + async def send( + self, messages: AsyncIterator[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: """Send a stream of messages.""" raise NotImplementedError() @@ -763,7 +772,9 @@ async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryR return UnaryResponse(message, conn.response_headers, conn.response_trailers) -async def recieve_stream_response[T](conn: StreamingClientConn, t: type[T], spec: Spec) -> StreamResponse[T]: +async def recieve_stream_response[T]( + conn: StreamingClientConn, t: type[T], spec: Spec, abort_event: asyncio.Event | None +) -> StreamResponse[T]: """Receive a stream response from a streaming client connection. Args: @@ -778,7 +789,7 @@ async def recieve_stream_response[T](conn: StreamingClientConn, t: type[T], spec if spec.stream_type == StreamType.ClientStream: count = 0 single_message: T | None = None - async for message in conn.receive(t): + async for message in conn.receive(t, abort_event): single_message = message count += 1 @@ -792,7 +803,7 @@ async def recieve_stream_response[T](conn: StreamingClientConn, t: type[T], spec return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) else: - return StreamResponse(conn.receive(t), conn.response_headers, conn.response_trailers) + return StreamResponse(conn.receive(t, abort_event), conn.response_headers, conn.response_trailers) async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 085beae..9fca352 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1,5 +1,6 @@ """Provides classes and functions for handling protocol connections.""" +import asyncio import base64 import contextlib import json @@ -1719,18 +1720,26 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def receive(self, message: Any) -> AsyncIterator[Any]: + async def receive(self, message: Any, abort_event: asyncio.Event | None = None) -> AsyncIterator[Any]: """Asynchronously receives and processes a message. Args: message (Any): The message to be processed. + abort_event (asyncio.Event | None): Event to signal abortion of the operation. Yields: Any: Objects obtained from unmarshaling the message. + Raises: + ConnectError: If stream is malformed or aborted. + """ end_stream_received = False + async for obj, end in self.unmarshaler.unmarshal(message): + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + if end: if end_stream_received: raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) @@ -1751,12 +1760,20 @@ async def receive(self, message: Any) -> AsyncIterator[Any]: if end_stream_received: raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + yield obj + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + if not end_stream_received: raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) - async def send(self, messages: AsyncIterator[Any], timeout: float | None) -> None: + async def send( + self, messages: AsyncIterator[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: """Send a series of messages asynchronously. This method marshals the provided messages, constructs an HTTP POST request, @@ -1774,6 +1791,9 @@ async def send(self, messages: AsyncIterator[Any], timeout: float | None) -> Non Exception: If there is an error during the request or response handling. """ + if abort_event and abort_event.is_set(): + raise ConnectError("request aborted", Code.CANCELED) + extensions = {} if timeout: extensions["timeout"] = {"read": timeout} @@ -1802,7 +1822,21 @@ async def send(self, messages: AsyncIterator[Any], timeout: float | None) -> Non hook(request) with map_httpcore_exceptions(): - response = await self.session.pool.handle_async_request(request) + if not abort_event: + response = await self.session.pool.handle_async_request(request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + raise ConnectError("request aborted", Code.CANCELED) + + response = await request_task + + abort_task.cancel() for hook in self._event_hooks["response"]: hook(response) @@ -1983,7 +2017,7 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def send(self, message: Any, timeout: float | None) -> bytes: + async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: """Send a message asynchronously and returns the marshaled data. Args: @@ -1997,6 +2031,9 @@ async def send(self, message: Any, timeout: float | None) -> bytes: Exception: If the response validation fails. """ + if abort_event and abort_event.is_set(): + raise ConnectError("request aborted", Code.CANCELED) + extensions = {} if timeout: extensions["timeout"] = {"read": timeout} @@ -2046,7 +2083,21 @@ async def send(self, message: Any, timeout: float | None) -> bytes: hook(request) with map_httpcore_exceptions(): - response = await self.session.pool.handle_async_request(request=request) + if not abort_event: + response = await self.session.pool.handle_async_request(request=request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + raise ConnectError("request aborted", Code.CANCELED) + + response = await request_task + + abort_task.cancel() for hook in self._event_hooks["response"]: hook(response) From 03f472137483aad89c15dda0c76fd7d02807e537 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Mon, 14 Apr 2025 22:15:16 +0900 Subject: [PATCH 02/16] protocol_connect: use queue for client stream --- src/connect/connect.py | 69 +++++++++++++++++++++++++-------- src/connect/protocol_connect.py | 57 +++++++++++++-------------- 2 files changed, 80 insertions(+), 46 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index 191e25d..f6b3998 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -601,7 +601,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: + async def receive(self, message: Any, queue: asyncio.Queue[Any]) -> None: """Receives a message and processes it.""" raise NotImplementedError() @@ -786,24 +786,61 @@ async def recieve_stream_response[T]( StreamResponse[T]: The stream response containing the received data, response headers, and response trailers. """ - if spec.stream_type == StreamType.ClientStream: - count = 0 - single_message: T | None = None - async for message in conn.receive(t, abort_event): - single_message = message - count += 1 + queue: asyncio.Queue[T] = asyncio.Queue() - if single_message is None: - raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) + producer_task = asyncio.create_task(conn.receive(t, queue)) + if abort_event is None: + abort_event = asyncio.Event() - if count > 1: - raise ConnectError( - "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED - ) + abort_task = asyncio.create_task(abort_event.wait()) - return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) - else: - return StreamResponse(conn.receive(t, abort_event), conn.response_headers, conn.response_trailers) + async def _iterate() -> AsyncIterator[T]: + try: + while not abort_task.done() and not producer_task.done(): + data = await queue.get() + + if data is None: + queue.task_done() + break + + yield data + queue.task_done() + finally: + if not producer_task.done(): + producer_task.cancel() + + await producer_task + + if not abort_task.done(): + abort_task.cancel() + + if producer_task.done() and (exc := producer_task.exception()): + raise exc + + return StreamResponse( + _iterate(), + conn.response_headers, + conn.response_trailers, + ) + + # if spec.stream_type == StreamType.ClientStream: + # count = 0 + # single_message: T | None = None + # async for message in conn.receive(t, abort_event): + # single_message = message + # count += 1 + + # if single_message is None: + # raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) + + # if count > 1: + # raise ConnectError( + # "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED + # ) + + # return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) + # else: + # return StreamResponse(conn.receive(t, abort_event), conn.response_headers, conn.response_trailers) async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 9fca352..aaed5cb 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -5,6 +5,7 @@ import contextlib import json from collections.abc import ( + AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable, @@ -1306,7 +1307,7 @@ def __init__( self._end_stream_error = None self._trailers = Headers() - async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: + async def unmarshal(self, message: Any) -> AsyncGenerator[tuple[Any, bool]]: """Asynchronously unmarshals messages from the stream. Args: @@ -1720,7 +1721,7 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def receive(self, message: Any, abort_event: asyncio.Event | None = None) -> AsyncIterator[Any]: + async def receive(self, message: Any, queue: asyncio.Queue[Any]) -> None: """Asynchronously receives and processes a message. Args: @@ -1735,41 +1736,37 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None = None) """ end_stream_received = False + unmarshal = self.unmarshaler.unmarshal(message) - async for obj, end in self.unmarshaler.unmarshal(message): - if abort_event and abort_event.is_set(): - raise ConnectError("receive operation aborted", Code.CANCELED) - - if end: - if end_stream_received: - raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) - - end_stream_received = True - error = self.unmarshaler.end_stream_error - if error: - for key, value in self.response_headers.items(): - error.metadata[key] = value - error.metadata.update(self.unmarshaler.trailers.copy()) - raise error - - for key, value in self.unmarshaler.trailers.items(): - self.response_trailers[key] = value + try: + async for obj, end in unmarshal: + if end: + if end_stream_received: + raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) - continue + end_stream_received = True + error = self.unmarshaler.end_stream_error + if error: + for key, value in self.response_headers.items(): + error.metadata[key] = value + error.metadata.update(self.unmarshaler.trailers.copy()) + raise error - if end_stream_received: - raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) + for key, value in self.unmarshaler.trailers.items(): + self.response_trailers[key] = value - if abort_event and abort_event.is_set(): - raise ConnectError("receive operation aborted", Code.CANCELED) + continue - yield obj + if end_stream_received: + raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) - if abort_event and abort_event.is_set(): - raise ConnectError("receive operation aborted", Code.CANCELED) + await queue.put(obj) - if not end_stream_received: - raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) + if not end_stream_received: + raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) + finally: + await unmarshal.aclose() + queue.put_nowait(None) async def send( self, messages: AsyncIterator[Any], timeout: float | None, abort_event: asyncio.Event | None From e30c2a6022d74955d061b1ef40afb0b1d89da4f9 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Tue, 15 Apr 2025 21:26:07 +0900 Subject: [PATCH 03/16] protocol_connect: fix abort_event cancellation --- conformance/client_config.yaml | 2 +- conformance/client_known_failing.yaml | 7 ++- conformance/client_runner.py | 11 +--- src/connect/connect.py | 69 ++++++------------------ src/connect/protocol_connect.py | 75 ++++++++++++++++----------- 5 files changed, 69 insertions(+), 95 deletions(-) diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index ebf7e0f..abd0452 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -22,4 +22,4 @@ features: supports_trailers: true supports_half_duplex_bidi_over_http1: true supports_connect_get: true - supports_message_receive_limit: true + supports_message_receive_limit: false diff --git a/conformance/client_known_failing.yaml b/conformance/client_known_failing.yaml index 046b5a9..cf66088 100644 --- a/conformance/client_known_failing.yaml +++ b/conformance/client_known_failing.yaml @@ -1,2 +1,7 @@ # Cancellation is not supported yet -Client Cancellation/** +Client Cancellation/**/bidi-stream/half-duplex/cancel-after-close-send +Client Cancellation/**/bidi-stream/half-duplex/cancel-before-close-send +Client Cancellation/**/bidi-stream/half-duplex/cancel-after-responses +Client Cancellation/**/bidi-stream/full-duplex/cancel-after-close-send +Client Cancellation/**/bidi-stream/full-duplex/cancel-before-close-send +Client Cancellation/**/bidi-stream/full-duplex/cancel-after-responses diff --git a/conformance/client_runner.py b/conformance/client_runner.py index 0e85c34..d6a7795 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -419,8 +419,6 @@ async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None: error=client_compat_pb2.ClientErrorResult(message="".join(traceback.format_exception(e))), ) - await asyncio.sleep(3) - write_response(resp) async def read_requests() -> None: @@ -429,10 +427,5 @@ async def read_requests() -> None: task = loop.create_task(run_message(req)) tasks.append(task) - try: - loop.run_until_complete(read_requests()) - finally: - pending = asyncio.all_tasks(loop) - if pending: - loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - loop.close() + loop.run_until_complete(read_requests()) + loop.close() diff --git a/src/connect/connect.py b/src/connect/connect.py index f6b3998..191e25d 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -601,7 +601,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - async def receive(self, message: Any, queue: asyncio.Queue[Any]) -> None: + def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: """Receives a message and processes it.""" raise NotImplementedError() @@ -786,61 +786,24 @@ async def recieve_stream_response[T]( StreamResponse[T]: The stream response containing the received data, response headers, and response trailers. """ - queue: asyncio.Queue[T] = asyncio.Queue() - - producer_task = asyncio.create_task(conn.receive(t, queue)) - if abort_event is None: - abort_event = asyncio.Event() - - abort_task = asyncio.create_task(abort_event.wait()) - - async def _iterate() -> AsyncIterator[T]: - try: - while not abort_task.done() and not producer_task.done(): - data = await queue.get() - - if data is None: - queue.task_done() - break - - yield data - queue.task_done() - finally: - if not producer_task.done(): - producer_task.cancel() - - await producer_task - - if not abort_task.done(): - abort_task.cancel() - - if producer_task.done() and (exc := producer_task.exception()): - raise exc - - return StreamResponse( - _iterate(), - conn.response_headers, - conn.response_trailers, - ) - - # if spec.stream_type == StreamType.ClientStream: - # count = 0 - # single_message: T | None = None - # async for message in conn.receive(t, abort_event): - # single_message = message - # count += 1 + if spec.stream_type == StreamType.ClientStream: + count = 0 + single_message: T | None = None + async for message in conn.receive(t, abort_event): + single_message = message + count += 1 - # if single_message is None: - # raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) + if single_message is None: + raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) - # if count > 1: - # raise ConnectError( - # "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED - # ) + if count > 1: + raise ConnectError( + "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED + ) - # return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) - # else: - # return StreamResponse(conn.receive(t, abort_event), conn.response_headers, conn.response_trailers) + return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) + else: + return StreamResponse(conn.receive(t, abort_event), conn.response_headers, conn.response_trailers) async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index aaed5cb..6225204 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -5,7 +5,6 @@ import contextlib import json from collections.abc import ( - AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable, @@ -1307,7 +1306,7 @@ def __init__( self._end_stream_error = None self._trailers = Headers() - async def unmarshal(self, message: Any) -> AsyncGenerator[tuple[Any, bool]]: + async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: """Asynchronously unmarshals messages from the stream. Args: @@ -1721,7 +1720,7 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def receive(self, message: Any, queue: asyncio.Queue[Any]) -> None: + async def receive(self, message: Any, abort_event: asyncio.Event | None = None) -> AsyncIterator[Any]: """Asynchronously receives and processes a message. Args: @@ -1736,37 +1735,41 @@ async def receive(self, message: Any, queue: asyncio.Queue[Any]) -> None: """ end_stream_received = False - unmarshal = self.unmarshaler.unmarshal(message) - try: - async for obj, end in unmarshal: - if end: - if end_stream_received: - raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) + async for obj, end in self.unmarshaler.unmarshal(message): + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) - end_stream_received = True - error = self.unmarshaler.end_stream_error - if error: - for key, value in self.response_headers.items(): - error.metadata[key] = value - error.metadata.update(self.unmarshaler.trailers.copy()) - raise error + if end: + if end_stream_received: + raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) - for key, value in self.unmarshaler.trailers.items(): - self.response_trailers[key] = value + end_stream_received = True + error = self.unmarshaler.end_stream_error + if error: + for key, value in self.response_headers.items(): + error.metadata[key] = value + error.metadata.update(self.unmarshaler.trailers.copy()) + raise error - continue + for key, value in self.unmarshaler.trailers.items(): + self.response_trailers[key] = value - if end_stream_received: - raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) + continue - await queue.put(obj) + if end_stream_received: + raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) - if not end_stream_received: - raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) - finally: - await unmarshal.aclose() - queue.put_nowait(None) + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + + yield obj + + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + + if not end_stream_received: + raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) async def send( self, messages: AsyncIterator[Any], timeout: float | None, abort_event: asyncio.Event | None @@ -1829,11 +1832,16 @@ async def send( if abort_task in done: request_task.cancel() - raise ConnectError("request aborted", Code.CANCELED) + with contextlib.suppress(asyncio.CancelledError): + await request_task - response = await request_task + raise ConnectError("request aborted", Code.CANCELED) abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task for hook in self._event_hooks["response"]: hook(response) @@ -2090,11 +2098,16 @@ async def send(self, message: Any, timeout: float | None, abort_event: asyncio.E if abort_task in done: request_task.cancel() - raise ConnectError("request aborted", Code.CANCELED) + with contextlib.suppress(asyncio.CancelledError): + await request_task - response = await request_task + raise ConnectError("request aborted", Code.CANCELED) abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task for hook in self._event_hooks["response"]: hook(response) From 07057fe7b9599aa196bb94842e3f9ebdb500aaf1 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Tue, 15 Apr 2025 21:33:09 +0900 Subject: [PATCH 04/16] protocol_connect: remove useless abort check --- src/connect/protocol_connect.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 6225204..a9b4a60 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1760,14 +1760,8 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None = None) if end_stream_received: raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) - if abort_event and abort_event.is_set(): - raise ConnectError("receive operation aborted", Code.CANCELED) - yield obj - if abort_event and abort_event.is_set(): - raise ConnectError("receive operation aborted", Code.CANCELED) - if not end_stream_received: raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) From 93038ea030e78ee1c3ce8dfaf20700c0520ea10e Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Thu, 17 Apr 2025 23:54:30 +0900 Subject: [PATCH 05/16] connect: safety stream close --- conformance/client_config.yaml | 2 +- conformance/client_runner.py | 56 ++++++--- conformance/run-testcase.txt | 2 +- src/connect/connect.py | 35 ++++-- src/connect/protocol_connect.py | 205 ++++++++++++++++++-------------- src/connect/utils.py | 45 +++++-- 6 files changed, 216 insertions(+), 129 deletions(-) diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index abd0452..ce19d9b 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -14,7 +14,7 @@ features: - STREAM_TYPE_CLIENT_STREAM - STREAM_TYPE_SERVER_STREAM - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM - - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM + # - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM supports_h2c: true supports_tls: true diff --git a/conformance/client_runner.py b/conformance/client_runner.py index d6a7795..0171369 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -321,10 +321,13 @@ async def delayed_abort() -> None: asyncio.create_task(delayed_abort()) - async for message in resp.messages: - payloads.append(message.payload) - if len(payloads) == msg.cancel.after_num_responses: - abort_event.set() + try: + async for message in resp.messages: + payloads.append(message.payload) + if len(payloads) == msg.cancel.after_num_responses: + abort_event.set() + finally: + await resp.aclose() return client_compat_pb2.ClientCompatResponse( test_name=msg.test_name, @@ -340,6 +343,7 @@ async def delayed_abort() -> None: msg.stream_type == config_pb2.STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM or msg.stream_type == config_pb2.STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM ): + abort_event = asyncio.Event() if msg.request_delay_ms > 0: await asyncio.sleep(msg.request_delay_ms / 1000.0) @@ -352,14 +356,35 @@ async def delayed_abort() -> None: resp = await getattr(client, msg.method)( StreamRequest( - messages=reqs, - headers=header, - timeout=msg.timeout_ms / 1000, + messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), ) - async for message in resp.messages: - payloads.append(message.payload) + if msg.HasField("cancel") and msg.cancel.HasField("before_close_send"): + abort_event.set() + + if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() + + asyncio.create_task(delayed_abort()) + + if ( + msg.HasField("cancel") + and msg.cancel.HasField("after_num_responses") + and msg.cancel.after_num_responses == 0 + ): + abort_event.set() + + try: + async for message in resp.messages: + payloads.append(message.payload) + if len(payloads) == msg.cancel.after_num_responses: + abort_event.set() + finally: + await resp.aclose() return client_compat_pb2.ClientCompatResponse( test_name=msg.test_name, @@ -403,12 +428,6 @@ async def delayed_abort() -> None: tracemalloc.start() - loop = asyncio.new_event_loop() - loop.set_debug(True) - asyncio.set_event_loop(loop) - - tasks = [] - async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None: """Run the message handler for a given request.""" try: @@ -423,9 +442,8 @@ async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None: async def read_requests() -> None: """Read requests from standard input and process them asynchronously.""" + loop = asyncio.get_event_loop() while req := await loop.run_in_executor(None, read_request): - task = loop.create_task(run_message(req)) - tasks.append(task) + loop.create_task(run_message(req)) - loop.run_until_complete(read_requests()) - loop.close() + asyncio.run(read_requests()) diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index 87b39f7..e3581b0 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -Client Cancellation/HTTPVersion:2/Protocol:PROTOCOL_CONNECT/Codec:CODEC_PROTO/Compression:COMPRESSION_GZIP/TLS:false/bidi-stream/full-duplex/cancel-after-responses +Client Cancellation/**/server-stream/cancel-after-responses diff --git a/src/connect/connect.py b/src/connect/connect.py index 191e25d..e09105f 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -2,7 +2,7 @@ import abc import asyncio -from collections.abc import AsyncIterator, Callable, Mapping +from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping from enum import Enum from http import HTTPMethod from typing import Any, Protocol, cast @@ -13,7 +13,7 @@ from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.utils import aiterate, get_callable_attribute +from connect.utils import AsyncIteratorByteStream, aiterate, get_callable_attribute class StreamType(Enum): @@ -293,23 +293,28 @@ def message(self) -> T: class StreamResponse[T](ResponseCommon): """Response class for handling responses.""" - _messages: AsyncIterator[T] + _messages: AsyncIterable[T] def __init__( self, - messages: AsyncIterator[T] | T, + messages: AsyncIterable[T] | T, headers: Headers | None = None, trailers: Headers | None = None, ) -> None: """Initialize the response with a message.""" super().__init__(headers, trailers) - self._messages = messages if isinstance(messages, AsyncIterator) else aiterate([messages]) + self._messages = messages if isinstance(messages, AsyncIterable) else aiterate([messages]) @property - def messages(self) -> AsyncIterator[T]: + def messages(self) -> AsyncIterable[T]: """Return the response message.""" return self._messages + async def aclose(self) -> None: + """Asynchronously close the response stream.""" + if hasattr(self._messages, "aclose"): + await self._messages.aclose() # type: ignore + class UnaryHandlerConn(abc.ABC): """Abstract base class for a streaming handler connection. @@ -482,7 +487,7 @@ def request_headers(self) -> Headers: raise NotImplementedError() @abc.abstractmethod - async def send(self, messages: AsyncIterator[Any]) -> None: + async def send(self, messages: AsyncIterable[Any]) -> None: """Send a stream of messages asynchronously. Args: @@ -584,6 +589,11 @@ def on_request_send(self, fn: Callable[..., Any]) -> None: """Handle the request send event.""" raise NotImplementedError() + @abc.abstractmethod + async def aclose(self) -> None: + """Asynchronously close the connection.""" + raise NotImplementedError() + class StreamingClientConn: """Abstract base class for a streaming client connection.""" @@ -635,6 +645,11 @@ def on_request_send(self, fn: Callable[..., Any]) -> None: """Handle the request send event.""" raise NotImplementedError() + @abc.abstractmethod + async def aclose(self) -> None: + """Asynchronously close the connection.""" + raise NotImplementedError() + class ReceiveConn(Protocol): """A protocol that defines the methods required for receiving connections.""" @@ -803,7 +818,11 @@ async def recieve_stream_response[T]( return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) else: - return StreamResponse(conn.receive(t, abort_event), conn.response_headers, conn.response_trailers) + return StreamResponse( + AsyncIteratorByteStream[T](conn.receive(t, abort_event), conn.aclose), + conn.response_headers, + conn.response_trailers, + ) async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index a9b4a60..1a88fe4 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -7,7 +7,6 @@ from collections.abc import ( AsyncIterable, AsyncIterator, - Awaitable, Callable, Mapping, ) @@ -56,7 +55,12 @@ ) from connect.request import Request from connect.session import AsyncClientSession -from connect.utils import AsyncByteStream, AsyncIteratorByteStream, aiterate, map_httpcore_exceptions +from connect.utils import ( + AsyncByteStream, + aiterate, + get_acallable_attribute, + map_httpcore_exceptions, +) from connect.version import __version__ from connect.writer import ServerResponseWriter @@ -424,7 +428,7 @@ class ConnectUnaryUnmarshaler: codec: Codec | None read_max_bytes: int compression: Compression | None - stream: AsyncIteratorByteStream | None + stream: AsyncIterable[bytes] | None def __init__( self, @@ -445,7 +449,7 @@ def __init__( self.codec = codec self.read_max_bytes = read_max_bytes self.compression = compression - self.stream = AsyncIteratorByteStream(stream) if stream else None + self.stream = stream async def unmarshal(self, message: Any) -> Any: """Asynchronously unmarshals a given message using the provided unmarshal function and codec. @@ -488,35 +492,43 @@ async def unmarshal_func(self, message: Any, func: Callable[[bytes, Any], Any]) chunks: list[bytes] = [] bytes_read = 0 - try: - async for chunk in self.stream: - chunk_size = len(chunk) - bytes_read += chunk_size - if self.read_max_bytes > 0 and bytes_read > self.read_max_bytes: - raise ConnectError( - f"message size {bytes_read} is larger than configured max {self.read_max_bytes}", - Code.RESOURCE_EXHAUSTED, - ) + async for chunk in self.stream: + chunk_size = len(chunk) + bytes_read += chunk_size + if self.read_max_bytes > 0 and bytes_read > self.read_max_bytes: + raise ConnectError( + f"message size {bytes_read} is larger than configured max {self.read_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) - chunks.append(chunk) + chunks.append(chunk) - data = b"".join(chunks) + data = b"".join(chunks) - if len(data) > 0 and self.compression: - data = self.compression.decompress(data, self.read_max_bytes) + if len(data) > 0 and self.compression: + data = self.compression.decompress(data, self.read_max_bytes) - try: - obj = func(data, message) - except Exception as e: - raise ConnectError( - f"unmarshal message: {str(e)}", - Code.INVALID_ARGUMENT, - ) from e - finally: - await self.stream.aclose() + try: + obj = func(data, message) + except Exception as e: + raise ConnectError( + f"unmarshal message: {str(e)}", + Code.INVALID_ARGUMENT, + ) from e return obj + async def aclose(self) -> None: + """Asynchronously close the stream if it is set. + + This method is intended to be called when the stream is no longer needed + to release any associated resources. + + """ + aclose = get_acallable_attribute(self.stream, "aclose") + if aclose: + await aclose() + class ConnectUnaryMarshaler: """ConnectUnaryMarshaler is responsible for serializing and optionally compressing messages. @@ -1104,38 +1116,44 @@ class ResponseAsyncByteStream(AsyncByteStream): """An asynchronous byte stream for reading and writing byte chunks.""" aiterator: AsyncIterable[bytes] | None - aclose_func: Callable[..., Awaitable[None]] | None + _closed: bool def __init__( self, aiterator: AsyncIterable[bytes] | None = None, - aclose_func: Callable[..., Awaitable[None]] | None = None, ) -> None: """Initialize the protocol connect instance. Args: aiterator (AsyncIterable[bytes] | None): An optional asynchronous iterable of bytes. - aclose_func (Callable[..., Awaitable[None]] | None): An optional asynchronous close function. Returns: None """ self.aiterator = aiterator - self.aclose_func = aclose_func + self._closed = False async def __aiter__(self) -> AsyncIterator[bytes]: """Asynchronous iterator method to read byte chunks from the stream.""" - if self.aiterator is not None: - with map_httpcore_exceptions(): - async for chunk in self.aiterator: - yield chunk + if self.aiterator: + try: + with map_httpcore_exceptions(): + async for chunk in self.aiterator: + yield chunk + except BaseException as exc: + await self.aclose() + raise exc async def aclose(self) -> None: """Asynchronously close the stream.""" - if self.aclose_func: + if not self._closed and self.aiterator: + aclose = get_acallable_attribute(self.aiterator, "aclose") + if not aclose: + return + with map_httpcore_exceptions(): - await self.aclose_func() + await aclose() class ConnectStreamingMarshaler: @@ -1169,7 +1187,7 @@ def __init__( self.send_max_bytes = send_max_bytes self.compression = compression - async def marshal(self, messages: AsyncIterator[Any]) -> AsyncIterator[bytes]: + async def marshal(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: """Asynchronously marshals and compresses messages from an asynchronous iterator. Args: @@ -1277,7 +1295,7 @@ class ConnectStreamingUnmarshaler: codec: Codec | None read_max_bytes: int compression: Compression | None - stream: AsyncIteratorByteStream | None + stream: AsyncIterable[bytes] | None buffer: bytes _end_stream_error: ConnectError | None _trailers: Headers @@ -1301,7 +1319,7 @@ def __init__( self.codec = codec self.read_max_bytes = read_max_bytes self.compression = compression - self.stream = AsyncIteratorByteStream(stream) if stream else None + self.stream = stream self.buffer = b"" self._end_stream_error = None self._trailers = Headers() @@ -1326,58 +1344,56 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: if self.codec is None: raise ConnectError("codec is not set", Code.INTERNAL) - try: - async for chunk in self.stream: - self.buffer += chunk - - while True: - env, data_len = Envelope.decode(self.buffer) - if env is None: - break + async for chunk in self.stream: + self.buffer += chunk - if self.read_max_bytes > 0 and data_len > self.read_max_bytes: - raise ConnectError( - f"message size {data_len} is larger than configured readMaxBytes {self.read_max_bytes}", - Code.RESOURCE_EXHAUSTED, - ) + while True: + env, data_len = Envelope.decode(self.buffer) + if env is None: + break - self.buffer = self.buffer[5 + data_len :] - - if env.is_set(EnvelopeFlags.compressed): - if not self.compression: - raise ConnectError( - "protocol error: sent compressed message without compression support", Code.INTERNAL - ) + if self.read_max_bytes > 0 and data_len > self.read_max_bytes: + raise ConnectError( + f"message size {data_len} is larger than configured readMaxBytes {self.read_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) - env.data = self.compression.decompress(env.data, self.read_max_bytes) + self.buffer = self.buffer[5 + data_len :] - if env.is_set(EnvelopeFlags.end_stream): - error, trailers = end_stream_from_bytes(env.data) - self._end_stream_error = error - self._trailers = trailers - end = True - obj = None - else: - try: - obj = self.codec.unmarshal(env.data, message) - except Exception as e: - raise ConnectError( - f"unmarshal message: {str(e)}", - Code.INVALID_ARGUMENT, - ) from e + if env.is_set(EnvelopeFlags.compressed): + if not self.compression: + raise ConnectError( + "protocol error: sent compressed message without compression support", Code.INTERNAL + ) - end = False + env.data = self.compression.decompress(env.data, self.read_max_bytes) + + if env.is_set(EnvelopeFlags.end_stream): + error, trailers = end_stream_from_bytes(env.data) + self._end_stream_error = error + self._trailers = trailers + end = True + obj = None + else: + try: + obj = self.codec.unmarshal(env.data, message) + except Exception as e: + raise ConnectError( + f"unmarshal message: {str(e)}", + Code.INVALID_ARGUMENT, + ) from e - yield obj, end + end = False - if len(self.buffer) > 0: - header = Envelope.decode_header(self.buffer) - if header: - message = f"protocol error: promised {header[1]} bytes in enveloped message, got {len(self.buffer) - 5} bytes" - raise ConnectError(message, Code.INVALID_ARGUMENT) + yield obj, end - finally: - await self.stream.aclose() + if len(self.buffer) > 0: + header = Envelope.decode_header(self.buffer) + if header: + message = ( + f"protocol error: promised {header[1]} bytes in enveloped message, got {len(self.buffer) - 5} bytes" + ) + raise ConnectError(message, Code.INVALID_ARGUMENT) @property def trailers(self) -> Headers: @@ -1402,6 +1418,11 @@ def end_stream_error(self) -> ConnectError | None: """ return self._end_stream_error + async def aclose(self) -> None: + aclose = get_acallable_attribute(self.stream, "aclose") + if aclose: + await aclose() + class ConnectStreamingHandlerConn(StreamingHandlerConn): """ConnectStreamingHandlerConn is a class that handles streaming connections for the Connect protocol. @@ -1507,7 +1528,7 @@ def request_headers(self) -> Headers: """ return self._request_headers - async def send(self, messages: AsyncIterator[Any]) -> None: + async def send(self, messages: AsyncIterable[Any]) -> None: """Send a stream of messages asynchronously. This method marshals the provided messages and sends them using the writer. @@ -1840,9 +1861,8 @@ async def send( for hook in self._event_hooks["response"]: hook(response) - self.unmarshaler.stream = AsyncIteratorByteStream( - ResponseAsyncByteStream(aiterator=response.aiter_stream(), aclose_func=response.aclose) - ) + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = ResponseAsyncByteStream(aiterator=response.stream) await self._validate_response(response) @@ -1885,6 +1905,9 @@ async def _validate_response(self, response: httpcore.Response) -> None: self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) self._response_headers.update(response_headers) + async def aclose(self) -> None: + await self.unmarshaler.aclose() + class ConnectUnaryClientConn(UnaryClientConn): """A client connection for unary RPCs using the Connect protocol. @@ -2106,9 +2129,8 @@ async def send(self, message: Any, timeout: float | None, abort_event: asyncio.E for hook in self._event_hooks["response"]: hook(response) - self.unmarshaler.stream = AsyncIteratorByteStream( - ResponseAsyncByteStream(response.aiter_stream(), response.aclose) - ) + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = ResponseAsyncByteStream(response.stream) await self._validate_response(response) @@ -2185,6 +2207,9 @@ def json_ummarshal(data: bytes, _message: Any) -> Any: wire_error.metadata.update(self._response_trailers) raise wire_error + async def aclose(self) -> None: + await self.unmarshaler.aclose() + @property def event_hooks(self) -> dict[str, list[EventHook]]: """Return the event hooks. diff --git a/src/connect/utils.py b/src/connect/utils.py index 9ef0d80..dbf0219 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -87,10 +87,23 @@ def get_callable_attribute(obj: object, attr: str) -> typing.Callable[..., typin typing.Callable[..., typing.Any] | None: The callable attribute if it exists and is callable, otherwise None. """ - if hasattr(obj, attr) and callable(getattr(obj, attr)): - return getattr(obj, attr) + try: + attr_value = getattr(obj, attr) + if callable(attr_value): + return attr_value + return None + except AttributeError: + return None + - return None +def get_acallable_attribute(obj: object, attr: str) -> typing.Callable[..., typing.Awaitable[typing.Any]] | None: + try: + attr_value = getattr(obj, attr) + if callable(attr_value) and is_async_callable(attr_value): + return attr_value + return None + except AttributeError: + return None def get_route_path(scope: Scope) -> str: @@ -192,7 +205,7 @@ def __init__(self) -> None: super().__init__("Stream has already been consumed.") -class AsyncIteratorByteStream: +class AsyncIteratorByteStream[T]: """An asynchronous iterator for byte streams. This class wraps an asynchronous iterable of bytes and provides an @@ -206,7 +219,13 @@ class AsyncIteratorByteStream: """ - def __init__(self, stream: typing.AsyncIterable[bytes]) -> None: + _stream: typing.AsyncIterable[T] + _is_stream_consumed: bool + aclose_func: Callable[..., Awaitable[None]] | None + + def __init__( + self, stream: typing.AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None + ) -> None: """Initialize the utility with an asynchronous byte stream. Args: @@ -215,8 +234,9 @@ def __init__(self, stream: typing.AsyncIterable[bytes]) -> None: """ self._stream = stream self._is_stream_consumed = False + self.aclose_func = aclose_func - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> typing.AsyncIterator[T]: """Asynchronously iterates over the stream and yields parts of it. Yields: @@ -230,8 +250,12 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: raise StreamConsumedError() self._is_stream_consumed = True - async for part in self._stream: - yield part + try: + async for part in self._stream: + yield part + except BaseException as exc: + await self.aclose() + raise exc async def aclose(self) -> None: """Asynchronously closes the stream if it has an `aclose` method. @@ -244,8 +268,9 @@ async def aclose(self) -> None: None """ - if isinstance(self._stream, AsyncByteStream): - await self._stream.aclose() + if self.aclose_func: + await self.aclose_func() + return async def aiterate[T](iterable: typing.Iterable[T]) -> typing.AsyncIterator[T]: From e46941a992a5c81d261fbb9b94c0476fe64a40f8 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 18 Apr 2025 19:03:38 +0900 Subject: [PATCH 06/16] connect: change to use async contextmanager for streaming --- conformance/client_runner.py | 64 +++++------ conformance/run-testcase.txt | 1 - src/connect/client.py | 34 ++++-- src/connect/connect.py | 5 - src/connect/protocol_connect.py | 44 ++++---- tests/test_streaming_connect_client.py | 147 ++++++++++++------------- 6 files changed, 145 insertions(+), 150 deletions(-) diff --git a/conformance/client_runner.py b/conformance/client_runner.py index 0171369..a32241c 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -277,14 +277,13 @@ async def delayed_abort() -> None: asyncio.create_task(delayed_abort()) - resp = await getattr(client, msg.method)( + async with getattr(client, msg.method)( StreamRequest( messages=_reqs(), headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), - ) - - async for message in resp.messages: - payloads.append(message.payload) + ) as resp: + async for message in resp.messages: + payloads.append(message.payload) return client_compat_pb2.ClientCompatResponse( test_name=msg.test_name, @@ -307,27 +306,23 @@ async def delayed_abort() -> None: else: header[h.name.lower()] = ", ".join(h.value) - resp = await getattr(client, msg.method)( + async with getattr(client, msg.method)( StreamRequest( messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), - ) - - if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + ) as resp: + if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): - async def delayed_abort() -> None: - await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) - abort_event.set() + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() - asyncio.create_task(delayed_abort()) + asyncio.create_task(delayed_abort()) - try: async for message in resp.messages: payloads.append(message.payload) if len(payloads) == msg.cancel.after_num_responses: abort_event.set() - finally: - await resp.aclose() return client_compat_pb2.ClientCompatResponse( test_name=msg.test_name, @@ -354,37 +349,33 @@ async def delayed_abort() -> None: else: header[h.name.lower()] = ", ".join(h.value) - resp = await getattr(client, msg.method)( + async with getattr(client, msg.method)( StreamRequest( messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), - ) - - if msg.HasField("cancel") and msg.cancel.HasField("before_close_send"): - abort_event.set() + ) as resp: + if msg.HasField("cancel") and msg.cancel.HasField("before_close_send"): + abort_event.set() - if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): - async def delayed_abort() -> None: - await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) - abort_event.set() + async def delayed_abort() -> None: + await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) + abort_event.set() - asyncio.create_task(delayed_abort()) + asyncio.create_task(delayed_abort()) - if ( - msg.HasField("cancel") - and msg.cancel.HasField("after_num_responses") - and msg.cancel.after_num_responses == 0 - ): - abort_event.set() + if ( + msg.HasField("cancel") + and msg.cancel.HasField("after_num_responses") + and msg.cancel.after_num_responses == 0 + ): + abort_event.set() - try: async for message in resp.messages: payloads.append(message.payload) if len(payloads) == msg.cancel.after_num_responses: abort_event.set() - finally: - await resp.aclose() return client_compat_pb2.ClientCompatResponse( test_name=msg.test_name, @@ -424,9 +415,6 @@ async def delayed_abort() -> None: if __name__ == "__main__": if "--debug" in sys.argv: logging.debug("Debug mode enabled") - import tracemalloc - - tracemalloc.start() async def run_message(req: client_compat_pb2.ClientCompatRequest) -> None: """Run the message handler for a given request.""" diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index e3581b0..e69de29 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +0,0 @@ -Client Cancellation/**/server-stream/cancel-after-responses diff --git a/src/connect/client.py b/src/connect/client.py index 720f5c9..1ef9e77 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -3,7 +3,8 @@ These classes allow making unary calls to a specified URL with given request and response types. """ -from collections.abc import Awaitable, Callable +import contextlib +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any import httpcore @@ -299,7 +300,8 @@ async def call_unary(self, request: UnaryRequest[T_Request]) -> UnaryResponse[T_ """ return await self._call_unary(request) - async def call_server_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: + @contextlib.asynccontextmanager + async def call_server_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: """Asynchronously calls a server streaming RPC (Remote Procedure Call) with the given request. Args: @@ -309,9 +311,14 @@ async def call_server_stream(self, request: StreamRequest[T_Request]) -> StreamR UnaryResponse[T_Response]: The response object containing the data received from the server. """ - return await self._call_stream(StreamType.ServerStream, request) - - async def call_client_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: + response = await self._call_stream(StreamType.ServerStream, request) + try: + yield response + finally: + await response.aclose() + + @contextlib.asynccontextmanager + async def call_client_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: """Asynchronously calls a client stream and yields responses. This method sends a stream request to the client and asynchronously @@ -324,9 +331,14 @@ async def call_client_stream(self, request: StreamRequest[T_Request]) -> StreamR UnaryResponse[T_Response]: The response from the client stream. """ - return await self._call_stream(StreamType.ClientStream, request) - - async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: + response = await self._call_stream(StreamType.ClientStream, request) + try: + yield response + finally: + await response.aclose() + + @contextlib.asynccontextmanager + async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: """Initiate a bidirectional streaming call. This method establishes a bidirectional stream between the client and the server, @@ -344,4 +356,8 @@ async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> StreamRes Any exceptions raised during the streaming call will propagate to the caller. """ - return await self._call_stream(StreamType.BiDiStream, request) + response = await self._call_stream(StreamType.BiDiStream, request) + try: + yield response + finally: + await response.aclose() diff --git a/src/connect/connect.py b/src/connect/connect.py index e09105f..1104665 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -589,11 +589,6 @@ def on_request_send(self, fn: Callable[..., Any]) -> None: """Handle the request send event.""" raise NotImplementedError() - @abc.abstractmethod - async def aclose(self) -> None: - """Asynchronously close the connection.""" - raise NotImplementedError() - class StreamingClientConn: """Abstract base class for a streaming client connection.""" diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 1a88fe4..bb96c72 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -492,29 +492,32 @@ async def unmarshal_func(self, message: Any, func: Callable[[bytes, Any], Any]) chunks: list[bytes] = [] bytes_read = 0 - async for chunk in self.stream: - chunk_size = len(chunk) - bytes_read += chunk_size - if self.read_max_bytes > 0 and bytes_read > self.read_max_bytes: - raise ConnectError( - f"message size {bytes_read} is larger than configured max {self.read_max_bytes}", - Code.RESOURCE_EXHAUSTED, - ) + try: + async for chunk in self.stream: + chunk_size = len(chunk) + bytes_read += chunk_size + if self.read_max_bytes > 0 and bytes_read > self.read_max_bytes: + raise ConnectError( + f"message size {bytes_read} is larger than configured max {self.read_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) - chunks.append(chunk) + chunks.append(chunk) - data = b"".join(chunks) + data = b"".join(chunks) - if len(data) > 0 and self.compression: - data = self.compression.decompress(data, self.read_max_bytes) + if len(data) > 0 and self.compression: + data = self.compression.decompress(data, self.read_max_bytes) - try: - obj = func(data, message) - except Exception as e: - raise ConnectError( - f"unmarshal message: {str(e)}", - Code.INVALID_ARGUMENT, - ) from e + try: + obj = func(data, message) + except Exception as e: + raise ConnectError( + f"unmarshal message: {str(e)}", + Code.INVALID_ARGUMENT, + ) from e + finally: + await self.aclose() return obj @@ -2207,9 +2210,6 @@ def json_ummarshal(data: bytes, _message: Any) -> Any: wire_error.metadata.update(self._response_trailers) raise wire_error - async def aclose(self) -> None: - await self.unmarshaler.aclose() - @property def event_hooks(self) -> dict[str, list[EventHook]]: """Return the event hooks. diff --git a/tests/test_streaming_connect_client.py b/tests/test_streaming_connect_client.py index 0f4b3d3..365c9dd 100644 --- a/tests/test_streaming_connect_client.py +++ b/tests/test_streaming_connect_client.py @@ -69,11 +69,11 @@ async def test_server_streaming(hypercorn_server: ServerConfig) -> None: client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) async def server_streaming_end_stream_error(scope: Scope, receive: Receive, send: Send) -> None: @@ -120,18 +120,18 @@ async def test_server_streaming_end_stream_error(hypercorn_server: ServerConfig) client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.UNAVAILABLE - assert excinfo.value.metadata["acme-operation-cost"] == "237" - assert excinfo.value.raw_message == "" - assert len(excinfo.value.details) == 0 - assert excinfo.value.wire_error is True + assert excinfo.value.code == Code.UNAVAILABLE + assert excinfo.value.metadata["acme-operation-cost"] == "237" + assert excinfo.value.raw_message == "" + assert len(excinfo.value.details) == 0 + assert excinfo.value.wire_error is True async def server_streaming_received_message_after_end_stream(scope: Scope, receive: Receive, send: Send) -> None: @@ -183,16 +183,16 @@ async def test_server_streaming_received_message_after_end_stream(hypercorn_serv client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.INVALID_ARGUMENT - assert excinfo.value.raw_message == "received message after end stream" + assert excinfo.value.code == Code.INVALID_ARGUMENT + assert excinfo.value.raw_message == "received message after end stream" async def server_streaming_received_extra_end_stream(scope: Scope, receive: Receive, send: Send) -> None: @@ -248,16 +248,16 @@ async def test_server_streaming_received_extra_end_stream(hypercorn_server: Serv client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.INVALID_ARGUMENT - assert excinfo.value.raw_message == "received extra end stream message" + assert excinfo.value.code == Code.INVALID_ARGUMENT + assert excinfo.value.raw_message == "received extra end stream message" async def server_streaming_not_received_end_stream(scope: Scope, receive: Receive, send: Send) -> None: @@ -299,16 +299,16 @@ async def test_server_streaming_not_received_end_stream(hypercorn_server: Server client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] - with pytest.raises(ConnectError) as excinfo: - async for message in response.messages: - assert message.name in want - want.remove(message.name) + with pytest.raises(ConnectError) as excinfo: + async for message in response.messages: + assert message.name in want + want.remove(message.name) - assert excinfo.value.code == Code.INVALID_ARGUMENT - assert excinfo.value.raw_message == "missing end stream message" + assert excinfo.value.code == Code.INVALID_ARGUMENT + assert excinfo.value.raw_message == "missing end stream message" async def server_streaming_response_envelope_message_compression(scope: Scope, receive: Receive, send: Send) -> None: @@ -356,11 +356,11 @@ async def test_server_streaming_response_envelope_message_compression(hypercorn_ client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) async def server_streaming_request_envelope_message_compression(scope: Scope, receive: Receive, send: Send) -> None: @@ -421,11 +421,11 @@ async def test_server_streaming_request_envelope_message_compression(hypercorn_s ) ping_request = StreamRequest(messages=PingRequest(name="Bob")) - response = await client.call_server_stream(ping_request) - want = ["Hi Bob.", "I'm Eliza."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_server_stream(ping_request) as response: + want = ["Hi Bob.", "I'm Eliza."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) @pytest.mark.asyncio() @@ -481,14 +481,13 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: ping_request = StreamRequest(messages=PingRequest(name="test")) - await client.call_server_stream(ping_request) + async with client.call_server_stream(ping_request): + assert len(ephemeral_files) == 2 + for i, ephemeral_file in enumerate(reversed(ephemeral_files)): + ephemeral_file.seek(0) + assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() - assert len(ephemeral_files) == 2 - for i, ephemeral_file in enumerate(reversed(ephemeral_files)): - ephemeral_file.seek(0) - assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() - - ephemeral_file.close() + ephemeral_file.close() async def server_streaming_not_httpstatus_200(scope: Scope, receive: Receive, send: Send) -> None: @@ -520,12 +519,11 @@ async def test_server_streaming_not_httpstatus_200(hypercorn_server: ServerConfi ping_request = StreamRequest(messages=PingRequest(name="Bob")) with pytest.raises(ConnectError) as excinfo: - await client.call_server_stream(ping_request) - - assert excinfo.value.code == Code.UNAVAILABLE - assert len(excinfo.value.details) == 0 - assert excinfo.value.wire_error is False - assert excinfo.value.metadata == {} + async with client.call_server_stream(ping_request): + assert excinfo.value.code == Code.UNAVAILABLE + assert len(excinfo.value.details) == 0 + assert excinfo.value.wire_error is False + assert excinfo.value.metadata == {} async def client_streaming(scope: Scope, receive: Receive, send: Send) -> None: @@ -582,11 +580,11 @@ async def iterator() -> AsyncIterator[PingRequest]: client = Client(session=session, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(messages=iterator()) - response = await client.call_client_stream(ping_request) - want = ["I'm fine."] - async for message in response.messages: - assert message.name in want - want.remove(message.name) + async with client.call_client_stream(ping_request) as response: + want = ["I'm fine."] + async for message in response.messages: + assert message.name in want + want.remove(message.name) @pytest.mark.asyncio() @@ -645,11 +643,10 @@ async def iterator() -> AsyncIterator[PingRequest]: ping_request = StreamRequest(messages=iterator()) - await client.call_client_stream(ping_request) - - assert len(ephemeral_files) == 2 - for i, ephemeral_file in enumerate(reversed(ephemeral_files)): - ephemeral_file.seek(0) - assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() + async with client.call_client_stream(ping_request): + assert len(ephemeral_files) == 2 + for i, ephemeral_file in enumerate(reversed(ephemeral_files)): + ephemeral_file.seek(0) + assert ephemeral_file.read() == f"interceptor: {i + 1}".encode() - ephemeral_file.close() + ephemeral_file.close() From 02ca9bc519e564094945a903dadfd8f5ab72dbe5 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 18 Apr 2025 21:55:25 +0900 Subject: [PATCH 07/16] all: update doc --- src/connect/connect.py | 22 ++++++++++---- src/connect/protocol_connect.py | 53 +++++++++++++++++++++++---------- src/connect/utils.py | 50 ++++++++++++++++++++++--------- 3 files changed, 91 insertions(+), 34 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index 1104665..3bf766f 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -170,6 +170,7 @@ def __init__( headers (Mapping[str, str]): The request headers. method (str): The HTTP method used for the request. timeout (float): The timeout for the request. + abort_event (asyncio.Event): An event to signal request abortion. Returns: None @@ -221,6 +222,7 @@ def __init__( headers (Mapping[str, str]): The request headers. method (str): The HTTP method used for the request. timeout (float): The timeout for the request. + abort_event (asyncio.Event): An event to signal request abortion. Returns: None @@ -785,15 +787,25 @@ async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryR async def recieve_stream_response[T]( conn: StreamingClientConn, t: type[T], spec: Spec, abort_event: asyncio.Event | None ) -> StreamResponse[T]: - """Receive a stream response from a streaming client connection. + """Handle the reception of a stream response based on the specified stream type. + + For `ClientStream` type, ensures that exactly one message is received. If no message + or multiple messages are received, raises a `ConnectError`. For other stream types, + returns a stream response wrapping the received messages. Args: - conn (StreamingClientConn): The streaming client connection. - t (type[T]): The type of the response to be received. - spec (Spec): The specification for the request. + conn (StreamingClientConn): The streaming connection used to receive messages. + t (type[T]): The expected type of the messages in the stream. + spec (Spec): The specification of the stream, including its type. + abort_event (asyncio.Event | None): An optional event to signal abortion of the stream. Returns: - StreamResponse[T]: The stream response containing the received data, response headers, and response trailers. + StreamResponse[T]: A stream response containing the received messages, response headers, + and response trailers. + + Raises: + ConnectError: If the stream type is `ClientStream` and no message or multiple messages + are received. """ if spec.stream_type == StreamType.ClientStream: diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index bb96c72..df3beba 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1422,6 +1422,15 @@ def end_stream_error(self) -> ConnectError | None: return self._end_stream_error async def aclose(self) -> None: + """Asynchronously closes the stream if it has an `aclose` method. + + This method checks if the `self.stream` object has an asynchronous + `aclose` method. If the method exists, it is invoked to close the stream. + + Returns: + None + + """ aclose = get_acallable_attribute(self.stream, "aclose") if aclose: await aclose() @@ -1792,21 +1801,25 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None = None) async def send( self, messages: AsyncIterator[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: - """Send a series of messages asynchronously. - - This method marshals the provided messages, constructs an HTTP POST request, - and sends it using the httpcore library. It also triggers any registered - request and response hooks, and validates the response. + """Send an asynchronous HTTP POST request with the given messages and handle the response. Args: messages (AsyncIterator[Any]): An asynchronous iterator of messages to be sent. - timeout (float | None): The timeout for the request in seconds. - - Returns: - None + timeout (float | None): Optional timeout value in seconds for the request. If provided, + it sets the read timeout for the request. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request. Raises: - Exception: If there is an error during the request or response handling. + ConnectError: If the request is aborted or if there is an error during the request. + + Hooks: + - Executes hooks registered in `self._event_hooks["request"]` before sending the request. + - Executes hooks registered in `self._event_hooks["response"]` after receiving the response. + + Notes: + - If `abort_event` is provided and set during the request, the request will be canceled, + and a `ConnectError` with code `Code.CANCELED` will be raised. + - The response stream is unmarshaled and validated after the request is completed. """ if abort_event and abort_event.is_set(): @@ -1909,6 +1922,12 @@ async def _validate_response(self, response: httpcore.Response) -> None: self._response_headers.update(response_headers) async def aclose(self) -> None: + """Asynchronously closes the connection by invoking the `aclose` method of the unmarshaler. + + Returns: + None + + """ await self.unmarshaler.aclose() @@ -2043,17 +2062,21 @@ def on_request_send(self, fn: EventHook) -> None: self._event_hooks["request"].append(fn) async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: - """Send a message asynchronously and returns the marshaled data. + """Send a message asynchronously using the specified HTTP method and handles the response. Args: - message (Any): The message to be sent. - timeout (float | None): The timeout for the request in seconds. + message (Any): The message to be sent, which will be marshaled before sending. + timeout (float | None): The timeout for the request in seconds. If provided, it will be + included in the request headers and extensions. + abort_event (asyncio.Event | None): An optional asyncio event that can be used to abort + the request. If the event is set, the request will be canceled. Returns: - bytes: The marshaled data of the message. + bytes: The marshaled data of the message that was sent. Raises: - Exception: If the response validation fails. + ConnectError: If the request is aborted or if there are issues during the request/response + lifecycle. """ if abort_event and abort_event.is_set(): diff --git a/src/connect/utils.py b/src/connect/utils.py index dbf0219..c30a92e 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -97,6 +97,17 @@ def get_callable_attribute(obj: object, attr: str) -> typing.Callable[..., typin def get_acallable_attribute(obj: object, attr: str) -> typing.Callable[..., typing.Awaitable[typing.Any]] | None: + """Retrieve an attribute from an object if it is both callable and asynchronous. + + Args: + obj (object): The object from which to retrieve the attribute. + attr (str): The name of the attribute to retrieve. + + Returns: + typing.Callable[..., typing.Awaitable[typing.Any]] | None: + The attribute if it is callable and asynchronous, otherwise None. + + """ try: attr_value = getattr(obj, attr) if callable(attr_value) and is_async_callable(attr_value): @@ -206,16 +217,20 @@ def __init__(self) -> None: class AsyncIteratorByteStream[T]: - """An asynchronous iterator for byte streams. + """An asynchronous iterator for streaming data of type `T`. - This class wraps an asynchronous iterable of bytes and provides an - asynchronous iterator interface. It ensures that the stream is only - consumed once and provides a method to close the stream if it supports - asynchronous closing. + This class wraps an asynchronous iterable and provides functionality to + ensure that the stream is consumed only once. It also supports an optional + cleanup function to be called when the stream is closed. + + Type Parameters: + T: The type of the items in the asynchronous iterable. Attributes: - _stream (typing.AsyncIterable[bytes]): The asynchronous iterable byte stream. + _stream (typing.AsyncIterable[T]): The asynchronous iterable to be streamed. _is_stream_consumed (bool): A flag indicating whether the stream has been consumed. + aclose_func (Callable[..., Awaitable[None]] | None): An optional asynchronous + cleanup function to be called when the stream is closed. """ @@ -226,10 +241,12 @@ class AsyncIteratorByteStream[T]: def __init__( self, stream: typing.AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None ) -> None: - """Initialize the utility with an asynchronous byte stream. + """Initialize an instance of the class. Args: - stream (typing.AsyncIterable[bytes]): An asynchronous iterable that yields bytes. + stream (typing.AsyncIterable[T]): An asynchronous iterable representing the stream of data. + aclose_func (Callable[..., Awaitable[None]] | None, optional): + A callable function that is awaited to close the stream. Defaults to None. """ self._stream = stream @@ -237,13 +254,19 @@ def __init__( self.aclose_func = aclose_func async def __aiter__(self) -> typing.AsyncIterator[T]: - """Asynchronously iterates over the stream and yields parts of it. + """Asynchronously iterates over the elements of the stream. + + This method allows the object to be used as an asynchronous iterator. + It ensures that the stream is not consumed multiple times and properly + handles cleanup in case of exceptions. Yields: - bytes: Parts of the stream. + T: The next element in the asynchronous stream. Raises: StreamConsumedError: If the stream has already been consumed. + BaseException: Propagates any exception raised during iteration + after ensuring the stream is closed. """ if self._is_stream_consumed: @@ -258,11 +281,10 @@ async def __aiter__(self) -> typing.AsyncIterator[T]: raise exc async def aclose(self) -> None: - """Asynchronously closes the stream if it has an `aclose` method. + """Asynchronously closes resources if an asynchronous close function is provided. - This method checks if the `_stream` attribute has an `aclose` method and - calls it asynchronously to close the stream. If the `_stream` does not - have an `aclose` method, this method does nothing. + This method checks if an `aclose_func` is defined. If it is, the function + is awaited to perform any necessary cleanup or resource deallocation. Returns: None From 427a3a8113f9b169bb7771b5b4ee6d2ea57c0580 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 18 Apr 2025 22:39:08 +0900 Subject: [PATCH 08/16] connect: fix validation for single stream --- conformance/run-testcase.txt | 2 + src/connect/connect.py | 86 +++++++++++++++++++++++------------- src/connect/utils.py | 2 +- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index e69de29..0fb9dec 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -0,0 +1,2 @@ +Connect Unexpected Responses/HTTPVersion:2/TLS:true/client-stream/multiple-responses +Connect Unexpected Responses/HTTPVersion:2/TLS:false/client-stream/ok-but-no-response diff --git a/src/connect/connect.py b/src/connect/connect.py index 3bf766f..f283ebc 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -2,7 +2,7 @@ import abc import asyncio -from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping from enum import Enum from http import HTTPMethod from typing import Any, Protocol, cast @@ -13,7 +13,7 @@ from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.utils import AsyncIteratorByteStream, aiterate, get_callable_attribute +from connect.utils import AsyncIteratorStream, aiterate, get_callable_attribute class StreamType(Enum): @@ -784,52 +784,78 @@ async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryR return UnaryResponse(message, conn.response_headers, conn.response_trailers) +async def _receive_exactly_one[T](stream: AsyncIterator[T], aclose: Callable[[], Awaitable[None]]) -> T: + """Asynchronously receives exactly one item from an asynchronous iterator. + + This function ensures that the provided asynchronous iterator (`stream`) yields + exactly one item. If the iterator yields no items or more than one item, a + `ConnectError` is raised. The provided `aclose` callable is always invoked to + close the stream, regardless of success or failure. + + Type Parameters: + T: The type of the items in the asynchronous iterator. + + Args: + stream (AsyncIterator[T]): The asynchronous iterator to consume. + aclose (Callable[[], Awaitable[None]]): A callable that closes the stream + when invoked. + + Returns: + T: The single item yielded by the asynchronous iterator. + + Raises: + ConnectError: If the iterator yields no items or more than one item. + + """ + try: + first = await stream.__anext__() + try: + await stream.__anext__() + raise ConnectError( + "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED + ) + except StopAsyncIteration: + return first + except StopAsyncIteration: + raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) from None + finally: + await aclose() + + async def recieve_stream_response[T]( conn: StreamingClientConn, t: type[T], spec: Spec, abort_event: asyncio.Event | None ) -> StreamResponse[T]: - """Handle the reception of a stream response based on the specified stream type. - - For `ClientStream` type, ensures that exactly one message is received. If no message - or multiple messages are received, raises a `ConnectError`. For other stream types, - returns a stream response wrapping the received messages. + """Handle receiving a stream response from a streaming client connection. Args: - conn (StreamingClientConn): The streaming connection used to receive messages. - t (type[T]): The expected type of the messages in the stream. + conn (StreamingClientConn): The streaming client connection used to receive the stream. + t (type[T]): The type of the messages expected in the stream. spec (Spec): The specification of the stream, including its type. abort_event (asyncio.Event | None): An optional event to signal abortion of the stream. Returns: - StreamResponse[T]: A stream response containing the received messages, response headers, + StreamResponse[T]: A response object containing the received stream, response headers, and response trailers. Raises: - ConnectError: If the stream type is `ClientStream` and no message or multiple messages - are received. + Any exceptions raised during the reception of the stream or processing of the messages. - """ - if spec.stream_type == StreamType.ClientStream: - count = 0 - single_message: T | None = None - async for message in conn.receive(t, abort_event): - single_message = message - count += 1 + Notes: + - If the stream type is `StreamType.ClientStream`, it expects exactly one message + and wraps it in a single-message stream. + - For other stream types, it directly returns the received stream. - if single_message is None: - raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) + """ + receive_stream = AsyncIteratorStream[T](conn.receive(t, abort_event), conn.aclose) - if count > 1: - raise ConnectError( - "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED - ) + if spec.stream_type == StreamType.ClientStream: + single_message = await _receive_exactly_one(receive_stream.__aiter__(), receive_stream.aclose) - return StreamResponse(aiterate([single_message]), conn.response_headers, conn.response_trailers) - else: return StreamResponse( - AsyncIteratorByteStream[T](conn.receive(t, abort_event), conn.aclose), - conn.response_headers, - conn.response_trailers, + AsyncIteratorStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers ) + else: + return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: diff --git a/src/connect/utils.py b/src/connect/utils.py index c30a92e..172ad50 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -216,7 +216,7 @@ def __init__(self) -> None: super().__init__("Stream has already been consumed.") -class AsyncIteratorByteStream[T]: +class AsyncIteratorStream[T]: """An asynchronous iterator for streaming data of type `T`. This class wraps an asynchronous iterable and provides functionality to From 740fd990aebdc207674cccba60aa44f2c0b56443 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 10:38:31 +0900 Subject: [PATCH 09/16] conformance: fix testcase --- conformance/client_known_failing.yaml | 8 +------- conformance/run-testcase-tmp.txt | 6 ------ 2 files changed, 1 insertion(+), 13 deletions(-) delete mode 100644 conformance/run-testcase-tmp.txt diff --git a/conformance/client_known_failing.yaml b/conformance/client_known_failing.yaml index cf66088..8b13789 100644 --- a/conformance/client_known_failing.yaml +++ b/conformance/client_known_failing.yaml @@ -1,7 +1 @@ -# Cancellation is not supported yet -Client Cancellation/**/bidi-stream/half-duplex/cancel-after-close-send -Client Cancellation/**/bidi-stream/half-duplex/cancel-before-close-send -Client Cancellation/**/bidi-stream/half-duplex/cancel-after-responses -Client Cancellation/**/bidi-stream/full-duplex/cancel-after-close-send -Client Cancellation/**/bidi-stream/full-duplex/cancel-before-close-send -Client Cancellation/**/bidi-stream/full-duplex/cancel-after-responses + diff --git a/conformance/run-testcase-tmp.txt b/conformance/run-testcase-tmp.txt deleted file mode 100644 index bcd9ca2..0000000 --- a/conformance/run-testcase-tmp.txt +++ /dev/null @@ -1,6 +0,0 @@ -Client Cancellation/**/unary/cancel-after-close-send -Client Cancellation/**/client-stream/cancel-before-close-send -Client Cancellation/**/client-stream/cancel-after-close-send -Client Cancellation/**/server-stream/cancel-after-close-send -Client Cancellation/**/server-stream/cancel-after-responses -Client Cancellation/HTTPVersion:2/Protocol:PROTOCOL_CONNECT/Codec:CODEC_PROTO/Compression:COMPRESSION_GZIP/TLS:false/bidi-stream/full-duplex/cancel-after-responses From 047ee5fb6f3244367355578526906e957c33bcf2 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 10:58:30 +0900 Subject: [PATCH 10/16] client_runner: fix checking cancel field --- conformance/client_runner.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/conformance/client_runner.py b/conformance/client_runner.py index a32241c..0aff6d5 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -218,7 +218,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c abort_event = asyncio.Event() req = await anext(reqs) - if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + if msg.cancel.after_close_send_ms > 0: async def delayed_abort() -> None: await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) @@ -267,9 +267,10 @@ async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: await asyncio.sleep(msg.request_delay_ms / 1000.0) yield req - if msg.HasField("cancel") and msg.cancel.HasField("before_close_send"): + if msg.cancel.HasField("before_close_send"): abort_event.set() - elif msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + + if msg.cancel.HasField("after_close_send_ms"): async def delayed_abort() -> None: await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) @@ -311,7 +312,7 @@ async def delayed_abort() -> None: messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), ) as resp: - if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + if msg.cancel.HasField("after_close_send_ms"): async def delayed_abort() -> None: await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) @@ -354,10 +355,10 @@ async def delayed_abort() -> None: messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), ) as resp: - if msg.HasField("cancel") and msg.cancel.HasField("before_close_send"): + if msg.cancel.HasField("before_close_send"): abort_event.set() - if msg.HasField("cancel") and msg.cancel.HasField("after_close_send_ms"): + if msg.cancel.HasField("after_close_send_ms"): async def delayed_abort() -> None: await asyncio.sleep(msg.cancel.after_close_send_ms / 1000) @@ -365,11 +366,7 @@ async def delayed_abort() -> None: asyncio.create_task(delayed_abort()) - if ( - msg.HasField("cancel") - and msg.cancel.HasField("after_num_responses") - and msg.cancel.after_num_responses == 0 - ): + if msg.cancel.HasField("after_num_responses") and msg.cancel.after_num_responses == 0: abort_event.set() async for message in resp.messages: From 7fbc42257a26f9975f7defdb1367d36f92b13f4e Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 11:13:59 +0900 Subject: [PATCH 11/16] client_runner: fix float value --- conformance/client_config.yaml | 2 +- conformance/client_runner.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index ce19d9b..22fb4f4 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -22,4 +22,4 @@ features: supports_trailers: true supports_half_duplex_bidi_over_http1: true supports_connect_get: true - supports_message_receive_limit: false + supports_message_receive_limit: true diff --git a/conformance/client_runner.py b/conformance/client_runner.py index 0aff6d5..9934fbb 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -213,7 +213,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c client = service_connect.ConformanceServiceClient(base_url=url, session=session, options=options) if msg.stream_type == config_pb2.STREAM_TYPE_UNARY: if msg.request_delay_ms > 0: - await asyncio.sleep(msg.request_delay_ms / 1000.0) + await asyncio.sleep(msg.request_delay_ms / 1000) abort_event = asyncio.Event() req = await anext(reqs) @@ -264,7 +264,7 @@ async def delayed_abort() -> None: async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: async for req in reqs: if msg.request_delay_ms > 0: - await asyncio.sleep(msg.request_delay_ms / 1000.0) + await asyncio.sleep(msg.request_delay_ms / 1000) yield req if msg.cancel.HasField("before_close_send"): @@ -298,7 +298,7 @@ async def delayed_abort() -> None: elif msg.stream_type == config_pb2.STREAM_TYPE_SERVER_STREAM: abort_event = asyncio.Event() if msg.request_delay_ms > 0: - await asyncio.sleep(msg.request_delay_ms / 1000.0) + await asyncio.sleep(msg.request_delay_ms / 1000) header = Headers() for h in msg.request_headers: @@ -340,8 +340,12 @@ async def delayed_abort() -> None: or msg.stream_type == config_pb2.STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM ): abort_event = asyncio.Event() - if msg.request_delay_ms > 0: - await asyncio.sleep(msg.request_delay_ms / 1000.0) + + async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: + async for req in reqs: + if msg.request_delay_ms > 0: + await asyncio.sleep(msg.request_delay_ms / 1000) + yield req header = Headers() for h in msg.request_headers: @@ -352,7 +356,7 @@ async def delayed_abort() -> None: async with getattr(client, msg.method)( StreamRequest( - messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event + messages=_reqs(), headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), ) as resp: if msg.cancel.HasField("before_close_send"): @@ -429,6 +433,8 @@ async def read_requests() -> None: """Read requests from standard input and process them asynchronously.""" loop = asyncio.get_event_loop() while req := await loop.run_in_executor(None, read_request): + await asyncio.sleep(0.01) + loop.create_task(run_message(req)) asyncio.run(read_requests()) From 1c82eb18b732fd46bfdac1e816a03e15e04f5d71 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 11:20:25 +0900 Subject: [PATCH 12/16] client_runner: fix create connect headers --- conformance/client_runner.py | 47 +++++++++++++++--------------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/conformance/client_runner.py b/conformance/client_runner.py index 9934fbb..b5203a8 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -141,6 +141,17 @@ def to_pb_headers(headers: Headers) -> list[service_pb2.Header]: ] +def to_connect_headers(pb_headers: RepeatedCompositeFieldContainer[service_pb2.Header]) -> Headers: + headers = Headers() + for h in pb_headers: + if key := headers.get(h.name.lower()): + headers[key] = f"{headers[key]}, {', '.join(h.value)}" + else: + headers[h.name.lower()] = ", ".join(h.value) + + return headers + + async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_compat_pb2.ClientCompatResponse: """Handle a client compatibility request and returns a response. @@ -226,17 +237,12 @@ async def delayed_abort() -> None: asyncio.create_task(delayed_abort()) - header = Headers() - for h in msg.request_headers: - if key := header.get(h.name.lower()): - header[key] = f"{header[key]}, {', '.join(h.value)}" - else: - header[h.name.lower()] = ", ".join(h.value) + headers = to_connect_headers(msg.request_headers) resp = await getattr(client, msg.method)( UnaryRequest( message=req, - headers=header, + headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event, ), @@ -254,12 +260,7 @@ async def delayed_abort() -> None: ) elif msg.stream_type == config_pb2.STREAM_TYPE_CLIENT_STREAM: abort_event = asyncio.Event() - header = Headers() - for h in msg.request_headers: - if key := header.get(h.name.lower()): - header[key] = f"{header[key]}, {', '.join(h.value)}" - else: - header[h.name.lower()] = ", ".join(h.value) + headers = to_connect_headers(msg.request_headers) async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: async for req in reqs: @@ -280,7 +281,7 @@ async def delayed_abort() -> None: async with getattr(client, msg.method)( StreamRequest( - messages=_reqs(), headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event + messages=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), ) as resp: async for message in resp.messages: @@ -300,16 +301,11 @@ async def delayed_abort() -> None: if msg.request_delay_ms > 0: await asyncio.sleep(msg.request_delay_ms / 1000) - header = Headers() - for h in msg.request_headers: - if key := header.get(h.name.lower()): - header[key] = f"{header[key]}, {', '.join(h.value)}" - else: - header[h.name.lower()] = ", ".join(h.value) + headers = to_connect_headers(msg.request_headers) async with getattr(client, msg.method)( StreamRequest( - messages=reqs, headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event + messages=reqs, headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), ) as resp: if msg.cancel.HasField("after_close_send_ms"): @@ -347,16 +343,11 @@ async def _reqs() -> AsyncGenerator[service_pb2.ClientStreamRequest]: await asyncio.sleep(msg.request_delay_ms / 1000) yield req - header = Headers() - for h in msg.request_headers: - if key := header.get(h.name.lower()): - header[key] = f"{header[key]}, {', '.join(h.value)}" - else: - header[h.name.lower()] = ", ".join(h.value) + headers = to_connect_headers(msg.request_headers) async with getattr(client, msg.method)( StreamRequest( - messages=_reqs(), headers=header, timeout=msg.timeout_ms / 1000, abort_event=abort_event + messages=_reqs(), headers=headers, timeout=msg.timeout_ms / 1000, abort_event=abort_event ), ) as resp: if msg.cancel.HasField("before_close_send"): From c666df58cd00a07ed79c90346aff94acbff6bbad Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 11:42:34 +0900 Subject: [PATCH 13/16] connect: AsyncIterator to AsyncIterable --- src/connect/connect.py | 17 ++++++++--------- src/connect/protocol_connect.py | 10 +++++----- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index f283ebc..8535ade 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -75,7 +75,6 @@ def __init__( """Initialize a new Request instance. Args: - messages (AsyncIterator[T]): An asynchronous iterator of messages. spec (Spec): The specification for the request. peer (Peer): The peer information. headers (Mapping[str, str]): The request headers. @@ -139,7 +138,7 @@ class StreamRequest[T](RequestCommon): """StreamRequest class represents a request that can handle streaming messages. Attributes: - messages (AsyncIterator[T]): An asynchronous iterator of messages. + messages (AsyncIterable[T]): An asynchronous iterable of messages. _spec (Spec): The specification for the request. _peer (Peer): The peer information. _headers (Headers): The request headers. @@ -147,13 +146,13 @@ class StreamRequest[T](RequestCommon): """ - _messages: AsyncIterator[T] + _messages: AsyncIterable[T] timeout: float | None abort_event: asyncio.Event | None = None def __init__( self, - messages: AsyncIterator[T] | T, + messages: AsyncIterable[T] | T, spec: Spec | None = None, peer: Peer | None = None, headers: Headers | None = None, @@ -164,7 +163,7 @@ def __init__( """Initialize a new Request instance. Args: - messages (AsyncIterator[T]): An asynchronous iterator of messages. + messages (AsyncIterable[T] | T): The request messages. spec (Spec): The specification for the request. peer (Peer): The peer information. headers (Mapping[str, str]): The request headers. @@ -177,12 +176,12 @@ def __init__( """ super().__init__(spec, peer, headers, method) - self._messages = messages if isinstance(messages, AsyncIterator) else aiterate([messages]) + self._messages = messages if isinstance(messages, AsyncIterable) else aiterate([messages]) self.timeout = timeout self.abort_event = abort_event @property - def messages(self) -> AsyncIterator[T]: + def messages(self) -> AsyncIterable[T]: """Return the request message.""" return self._messages @@ -493,7 +492,7 @@ async def send(self, messages: AsyncIterable[Any]) -> None: """Send a stream of messages asynchronously. Args: - messages (AsyncIterator[Any]): An asynchronous iterator that yields messages to be sent. + messages (AsyncIterable[Any]): The messages to be sent. Raises: NotImplementedError: This method should be implemented by subclasses. @@ -620,7 +619,7 @@ def request_headers(self) -> Headers: @abc.abstractmethod async def send( - self, messages: AsyncIterator[Any], timeout: float | None, abort_event: asyncio.Event | None + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: """Send a stream of messages.""" raise NotImplementedError() diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index df3beba..32ca3f0 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1194,7 +1194,7 @@ async def marshal(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: """Asynchronously marshals and compresses messages from an asynchronous iterator. Args: - messages (AsyncIterator[Any]): An asynchronous iterator of messages to be marshaled. + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled. Yields: AsyncIterator[bytes]: An asynchronous iterator of marshaled and optionally compressed messages in bytes. @@ -1290,7 +1290,7 @@ class ConnectStreamingUnmarshaler: Attributes: codec (Codec): The codec used for unmarshaling data. compression (Compression | None): The compression method used, if any. - stream (AsyncIteratorByteStream | None): The asynchronous byte stream to read data from. + stream (AsyncIterable[bytes] | None): The asynchronous byte stream to read from. buffer (bytes): The buffer to store incoming data chunks. """ @@ -1548,7 +1548,7 @@ async def send(self, messages: AsyncIterable[Any]) -> None: converts it to a JSON object, and sends it as the final message in the stream. Args: - messages (AsyncIterator[Any]): An asynchronous iterator of messages to be sent. + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. Returns: None @@ -1799,12 +1799,12 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None = None) raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) async def send( - self, messages: AsyncIterator[Any], timeout: float | None, abort_event: asyncio.Event | None + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: """Send an asynchronous HTTP POST request with the given messages and handle the response. Args: - messages (AsyncIterator[Any]): An asynchronous iterator of messages to be sent. + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. timeout (float | None): Optional timeout value in seconds for the request. If provided, it sets the read timeout for the request. abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request. From 7472f5cff5578fd7f5733ffae363589064a1ae33 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 11:51:31 +0900 Subject: [PATCH 14/16] connect: rename AsyncDataStream --- src/connect/connect.py | 6 +++--- src/connect/protocol_connect.py | 6 +++--- src/connect/utils.py | 27 ++++++++++++++++----------- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index 8535ade..c9aeb49 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -13,7 +13,7 @@ from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.utils import AsyncIteratorStream, aiterate, get_callable_attribute +from connect.utils import AsyncDataStream, aiterate, get_callable_attribute class StreamType(Enum): @@ -845,13 +845,13 @@ async def recieve_stream_response[T]( - For other stream types, it directly returns the received stream. """ - receive_stream = AsyncIteratorStream[T](conn.receive(t, abort_event), conn.aclose) + receive_stream = AsyncDataStream[T](conn.receive(t, abort_event), conn.aclose) if spec.stream_type == StreamType.ClientStream: single_message = await _receive_exactly_one(receive_stream.__aiter__(), receive_stream.aclose) return StreamResponse( - AsyncIteratorStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers + AsyncDataStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers ) else: return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 32ca3f0..ee3e2d2 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1115,7 +1115,7 @@ def _write_with_get(self, url: URL) -> None: self.url = url -class ResponseAsyncByteStream(AsyncByteStream): +class HTTPCoreResponseAsyncByteStream(AsyncByteStream): """An asynchronous byte stream for reading and writing byte chunks.""" aiterator: AsyncIterable[bytes] | None @@ -1878,7 +1878,7 @@ async def send( hook(response) assert isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = ResponseAsyncByteStream(aiterator=response.stream) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) await self._validate_response(response) @@ -2156,7 +2156,7 @@ async def send(self, message: Any, timeout: float | None, abort_event: asyncio.E hook(response) assert isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = ResponseAsyncByteStream(response.stream) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(response.stream) await self._validate_response(response) diff --git a/src/connect/utils.py b/src/connect/utils.py index 172ad50..1bab7b7 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -216,21 +216,26 @@ def __init__(self) -> None: super().__init__("Stream has already been consumed.") -class AsyncIteratorStream[T]: - """An asynchronous iterator for streaming data of type `T`. - - This class wraps an asynchronous iterable and provides functionality to - ensure that the stream is consumed only once. It also supports an optional - cleanup function to be called when the stream is closed. +class AsyncDataStream[T]: + """AsyncDataStream is a generic class that provides an asynchronous iterable interface for streaming data. + It ensures that the stream is consumed only once and provides a mechanism for resource cleanup. Type Parameters: - T: The type of the items in the asynchronous iterable. + T: The type of elements in the asynchronous stream. Attributes: - _stream (typing.AsyncIterable[T]): The asynchronous iterable to be streamed. - _is_stream_consumed (bool): A flag indicating whether the stream has been consumed. - aclose_func (Callable[..., Awaitable[None]] | None): An optional asynchronous - cleanup function to be called when the stream is closed. + _stream (typing.AsyncIterable[T]): The asynchronous iterable representing the stream of data. + _is_stream_consumed (bool): A flag indicating whether the stream has already been consumed. + aclose_func (Callable[..., Awaitable[None]] | None): An optional asynchronous callable for closing resources. + + Methods: + __init__(stream: typing.AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None) -> None: + Initializes the AsyncDataStream instance with the given stream and optional close function. + __aiter__() -> typing.AsyncIterator[T]: + Asynchronously iterates over the elements of the stream. Ensures the stream is consumed only once + and handles cleanup in case of exceptions. + aclose() -> None: + Asynchronously closes resources if an asynchronous close function is provided. """ From 4269673fd656c23a93b19a1d26257a58b0c5f382 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 11:53:57 +0900 Subject: [PATCH 15/16] connect: remove ignore lint --- src/connect/connect.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index c9aeb49..5584561 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -13,7 +13,7 @@ from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.utils import AsyncDataStream, aiterate, get_callable_attribute +from connect.utils import AsyncDataStream, aiterate, get_acallable_attribute, get_callable_attribute class StreamType(Enum): @@ -313,8 +313,9 @@ def messages(self) -> AsyncIterable[T]: async def aclose(self) -> None: """Asynchronously close the response stream.""" - if hasattr(self._messages, "aclose"): - await self._messages.aclose() # type: ignore + aclose = get_acallable_attribute(self._messages, "aclose") + if aclose: + await aclose() class UnaryHandlerConn(abc.ABC): From 713269ce11d84dd1aa5ea561985be2a364dbe2ce Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 19 Apr 2025 12:00:08 +0900 Subject: [PATCH 16/16] connect: update doc --- src/connect/client.py | 58 +++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/src/connect/client.py b/src/connect/client.py index 1ef9e77..b3393d0 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -302,13 +302,23 @@ async def call_unary(self, request: UnaryRequest[T_Request]) -> UnaryResponse[T_ @contextlib.asynccontextmanager async def call_server_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: - """Asynchronously calls a server streaming RPC (Remote Procedure Call) with the given request. + """Initiate a server-streaming RPC call and returns an asynchronous generator that yields responses from the server. Args: - request (UnaryRequest[T_Request]): The request object containing the data to be sent to the server. + request (StreamRequest[T_Request]): The request object containing the + data to be sent to the server. - Returns: - UnaryResponse[T_Response]: The response object containing the data received from the server. + Yields: + StreamResponse[T_Response]: The response objects received from the server. + + Raises: + Any exceptions that occur during the streaming process. + + Notes: + - This method ensures that the response stream is properly closed + after the generator is exhausted or an exception occurs. + - The type parameters `T_Request` and `T_Response` represent the + request and response types, respectively. """ response = await self._call_stream(StreamType.ServerStream, request) @@ -319,16 +329,22 @@ async def call_server_stream(self, request: StreamRequest[T_Request]) -> AsyncGe @contextlib.asynccontextmanager async def call_client_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: - """Asynchronously calls a client stream and yields responses. - - This method sends a stream request to the client and asynchronously - iterates over the responses, yielding each response one by one. + """Initiate a client-streaming RPC call and returns an asynchronous generator for streaming responses from the server. Args: - request (StreamRequest[T_Request]): The stream request to be sent. + request (StreamRequest[T_Request]): The request object containing the + client-streaming data to be sent to the server. Yields: - UnaryResponse[T_Response]: The response from the client stream. + StreamResponse[T_Response]: An asynchronous generator that yields + responses from the server. + + Raises: + Any exceptions raised during the streaming call. + + Notes: + - The `response.aclose()` method is called in the `finally` block to + ensure proper cleanup of the response stream. """ response = await self._call_stream(StreamType.ClientStream, request) @@ -339,21 +355,25 @@ async def call_client_stream(self, request: StreamRequest[T_Request]) -> AsyncGe @contextlib.asynccontextmanager async def call_bidi_stream(self, request: StreamRequest[T_Request]) -> AsyncGenerator[StreamResponse[T_Response]]: - """Initiate a bidirectional streaming call. + """Initiate a bidirectional streaming call with the server. - This method establishes a bidirectional stream between the client and the server, - allowing both to send and receive messages asynchronously. + This method sends a stream request to the server and returns an asynchronous + generator that yields stream responses from the server. The connection is + automatically closed when the generator is exhausted or an exception occurs. Args: - request (StreamRequest[T_Request]): The request object containing the stream - of messages to be sent to the server. + request (StreamRequest[T_Request]): The stream request object containing + the data to be sent to the server. - Returns: - StreamResponse[T_Response]: An asynchronous stream response object that - allows receiving messages from the server. + Yields: + StreamResponse[T_Response]: The stream response object received from the server. Raises: - Any exceptions raised during the streaming call will propagate to the caller. + Any exceptions raised during the streaming call. + + Notes: + Ensure to consume the generator properly to avoid resource leaks, as the + connection is closed in the `finally` block. """ response = await self._call_stream(StreamType.BiDiStream, request)