diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index 22fb4f4..0ad7321 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -4,6 +4,7 @@ features: - HTTP_VERSION_2 protocols: - PROTOCOL_CONNECT + - PROTOCOL_GRPC codecs: - CODEC_PROTO compressions: @@ -22,4 +23,4 @@ features: supports_trailers: true supports_half_duplex_bidi_over_http1: true supports_connect_get: true - supports_message_receive_limit: true + supports_message_receive_limit: false diff --git a/conformance/client_runner.py b/conformance/client_runner.py index b5203a8..6e1fa50 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -218,6 +218,11 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c payloads = [] try: options = ClientOptions() + if msg.protocol == config_pb2.PROTOCOL_GRPC: + options.protocol = "grpc" + if msg.protocol == config_pb2.PROTOCOL_GRPC_WEB: + options.protocol = "grpc-web" + if msg.compression == config_pb2.COMPRESSION_GZIP: options.request_compression_name = "gzip" @@ -424,8 +429,6 @@ async def read_requests() -> None: """Read requests from standard input and process them asynchronously.""" loop = asyncio.get_event_loop() while req := await loop.run_in_executor(None, read_request): - await asyncio.sleep(0.01) - loop.create_task(run_message(req)) asyncio.run(read_requests()) diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index 26b35a5..a71ce69 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -gRPC Unexpected Requests/HTTPVersion:2/TLS:true/unary/multiple-requests +Timeouts/HTTPVersion:2/Protocol:PROTOCOL_GRPC/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:true/unary diff --git a/conformance/uv.lock b/conformance/uv.lock index a539222..5f64f6d 100644 --- a/conformance/uv.lock +++ b/conformance/uv.lock @@ -108,7 +108,7 @@ requires-dist = [ { name = "anyio", specifier = ">=4.7.0" }, { name = "googleapis-common-protos", specifier = ">=1.70.0" }, { name = "h2", specifier = ">=4.2.0" }, - { name = "httpcore", specifier = ">=1.0.7" }, + { name = "httpcore", git = "https://github.com/tsubakiky/httpcore" }, { name = "protobuf", specifier = ">=5.29.1" }, { name = "pydantic", specifier = ">=2.10.4" }, { name = "starlette", specifier = ">=0.46.0" }, @@ -179,11 +179,11 @@ wheels = [ [[package]] name = "h11" -version = "0.14.0" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418, upload_time = "2022-09-25T15:40:01.519Z" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload_time = "2025-04-24T03:35:25.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload_time = "2022-09-25T15:39:59.68Z" }, + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload_time = "2025-04-24T03:35:24.344Z" }, ] [[package]] @@ -210,16 +210,12 @@ wheels = [ [[package]] name = "httpcore" -version = "1.0.7" -source = { registry = "https://pypi.org/simple" } +version = "1.0.9" +source = { git = "https://github.com/tsubakiky/httpcore#e70c821d72d7b9c5634c781cd454ced911052c29" } dependencies = [ { name = "certifi" }, { name = "h11" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196, upload_time = "2024-11-15T12:30:47.531Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551, upload_time = "2024-11-15T12:30:45.782Z" }, -] [[package]] name = "hypercorn" diff --git a/pyproject.toml b/pyproject.toml index be7b1d1..dcae5ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "anyio>=4.7.0", "googleapis-common-protos>=1.70.0", "h2>=4.2.0", - "httpcore>=1.0.7", + "httpcore", "protobuf>=5.29.1", "pydantic>=2.10.4", "starlette>=0.46.0", @@ -22,6 +22,9 @@ dependencies = [ "yarl>=1.18.3", ] +[tool.uv.sources] +httpcore = { git = "https://github.com/tsubakiky/httpcore" } + [tool.hatch.build.targets.wheel] packages = ["src/connect"] diff --git a/src/connect/byte_stream.py b/src/connect/byte_stream.py new file mode 100644 index 0000000..5033d9d --- /dev/null +++ b/src/connect/byte_stream.py @@ -0,0 +1,56 @@ +"""Asynchronous byte stream utilities for HTTP core response handling.""" + +from collections.abc import ( + AsyncIterable, + AsyncIterator, +) + +from connect.utils import ( + AsyncByteStream, + get_acallable_attribute, + map_httpcore_exceptions, +) + + +class HTTPCoreResponseAsyncByteStream(AsyncByteStream): + """An asynchronous byte stream for reading and writing byte chunks.""" + + aiterator: AsyncIterable[bytes] | None + _closed: bool + + def __init__( + self, + aiterator: AsyncIterable[bytes] | None = None, + ) -> None: + """Initialize the protocol connect instance. + + Args: + aiterator (AsyncIterable[bytes] | None): An optional asynchronous iterable of bytes. + + Returns: + None + + """ + self.aiterator = aiterator + self._closed = False + + async def __aiter__(self) -> AsyncIterator[bytes]: + """Asynchronous iterator method to read byte chunks from the stream.""" + if self.aiterator: + try: + with map_httpcore_exceptions(): + async for chunk in self.aiterator: + yield chunk + except BaseException as exc: + await self.aclose() + raise exc + + async def aclose(self) -> None: + """Asynchronously close the stream.""" + if not self._closed and self.aiterator: + aclose = get_acallable_attribute(self.aiterator, "aclose") + if not aclose: + return + + with map_httpcore_exceptions(): + await aclose() diff --git a/src/connect/client.py b/src/connect/client.py index b3393d0..913fc26 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -27,9 +27,11 @@ from connect.idempotency_level import IdempotencyLevel from connect.interceptor import apply_interceptors from connect.options import ClientOptions -from connect.protocol import ProtocolClient, ProtocolClientParams +from connect.protocol import Protocol, ProtocolClient, ProtocolClientParams from connect.protocol_connect import ProtocolConnect +from connect.protocol_grpc import ProtocolGRPC from connect.session import AsyncClientSession +from connect.utils import aiterate def parse_request_url(raw_url: str) -> URL: @@ -75,7 +77,7 @@ class ClientConfig: """ url: URL - protocol: ProtocolConnect + protocol: Protocol procedure: str codec: Codec request_compression_name: str | None @@ -113,6 +115,10 @@ def __init__(self, raw_url: str, options: ClientOptions): self.url = url self.protocol = ProtocolConnect() + if options.protocol == "grpc": + self.protocol = ProtocolGRPC(web=False) + elif options.protocol == "grpc-web": + self.protocol = ProtocolGRPC(web=True) self.procedure = proto_path self.codec = ProtoBinaryCodec() self.request_compression_name = options.request_compression_name @@ -228,9 +234,9 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(request.message, request.timeout, abort_event=request.abort_event) + await conn.send(aiterate([request.message]), request.timeout, abort_event=request.abort_event) - response = await recieve_unary_response(conn=conn, t=output) + response = await recieve_unary_response(conn=conn, t=output, abort_event=request.abort_event) return response unary_func = apply_interceptors(_unary_func, options.interceptors) @@ -257,7 +263,7 @@ async def call_unary(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Respon return response async def _stream_func(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - conn = protocol_client.stream_conn(request.spec, request.headers) + conn = protocol_client.conn(request.spec, request.headers) def on_request_send(r: httpcore.Request) -> None: method = r.method diff --git a/src/connect/connect.py b/src/connect/connect.py index b550b3a..b4fd760 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -318,94 +318,18 @@ async def aclose(self) -> None: await aclose() -class AsyncContentStream[T](AsyncIterable[T]): - """AsyncContentStream is a generic asynchronous stream wrapper for async iterables, providing validation and iteration utilities based on stream type. - - Type Parameters: - T: The type of elements yielded by the asynchronous iterable. - - iterable (AsyncIterable[T]): The asynchronous iterable to wrap. - stream_type (StreamType): The type of stream (e.g., Unary, ServerStream) that determines validation behavior. - - Attributes: - _iterable (AsyncIterable[T]): The underlying asynchronous iterable. - stream_type (StreamType): The type of stream this instance represents. - - """ - - def __init__(self, iterable: AsyncIterable[T], stream_type: StreamType) -> None: - """Initialize a stream wrapper for an async iterable. - - This constructor stores the provided async iterable and its corresponding - stream type for later processing. - - Args: - iterable: An asynchronous iterable containing elements of type T. - stream_type: The type of stream this iterable represents. - - Returns: - None - - """ - self._iterable = iterable - self.stream_type = stream_type - - async def __aiter__(self) -> AsyncIterator[T]: - """Asynchronously iterates over the underlying iterable. - - If single message validation is required, wraps the iterable with a validation step. - Otherwise, yields items directly from the iterable. - - Yields: - T: Items from the underlying asynchronous iterable. - - """ - if self.stream_type == StreamType.Unary or self.stream_type == StreamType.ServerStream: - async for item in validate_single_content_stream(self._iterable): - yield item - else: - async for item in self._iterable: - yield item - - -async def validate_single_content_stream[T](iterable: AsyncIterable[T]) -> AsyncIterator[T]: - """Validate that an asynchronous iterable yields exactly one item. - - This async generator iterates over the provided `iterable` and ensures that it produces exactly one item. - If more than one item is yielded, a `ConnectError` is raised indicating a protocol error. - If no items are yielded, a `ConnectError` is also raised. - - Args: - iterable (AsyncIterable[T]): The asynchronous iterable to validate. - - Yields: - T: The single item from the iterable. - - Raises: - ConnectError: If the iterable yields zero or more than one item. - - """ - count = 0 - async for item in iterable: - if count > 0: - raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) - - yield item - count += 1 - - if count == 0: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) - - -async def ensure_single[T](iterable: AsyncIterable[T]) -> T: +async def ensure_single[T](iterable: AsyncIterable[T], aclose: Callable[[], Awaitable[None]] | None = None) -> T: """Asynchronously ensures that the given async iterable yields exactly one item. Iterates over the provided async iterable (after validating its content stream) and returns the single item if present. Raises a ConnectError if the iterable - is empty or contains more than one item. + is empty or contains more than one item. Optionally closes the iterable by calling + the provided aclose function after processing. Args: iterable (AsyncIterable[T]): An asynchronous iterable expected to yield exactly one item. + aclose (Callable[[], Awaitable[None]] | None, optional): A callable that asynchronously + closes the stream when invoked. If provided, will be called in a finally block. Returns: T: The single item yielded by the iterable. @@ -414,14 +338,20 @@ async def ensure_single[T](iterable: AsyncIterable[T]) -> T: ConnectError: If the iterable yields no items or more than one item. """ - message = None - async for item in validate_single_content_stream(iterable): - message = item - - if message is None: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) - - return message + try: + iterator = iterable.__aiter__() + try: + first = await iterator.__anext__() + try: + await iterator.__anext__() + raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) + except StopAsyncIteration: + return first + except StopAsyncIteration: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) from None + finally: + if aclose: + await aclose() class StreamingHandlerConn(abc.ABC): @@ -462,7 +392,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and returns an asynchronous content stream. Args: @@ -694,69 +624,51 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S peer information, request headers, and HTTP method. """ - stream = conn.receive(t) - - return StreamRequest( - messages=stream, - spec=conn.spec, - peer=conn.peer, - headers=conn.request_headers, - method=HTTPMethod.POST, - ) - - -async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryResponse[T]: - """Receive a unary response from a streaming client connection. - - Args: - conn (StreamingClientConn): The streaming client connection. - t (type[T]): The type of the expected response message. - - Returns: - UnaryResponse[T]: The response containing the message, response headers, and response trailers. - - """ - message = await receive_unary_message(conn, t) - - return UnaryResponse(message, conn.response_headers, conn.response_trailers) - + if conn.spec.stream_type == StreamType.ServerStream: + message = await ensure_single(conn.receive(t)) + + return StreamRequest( + messages=aiterate([message]), + spec=conn.spec, + peer=conn.peer, + headers=conn.request_headers, + method=HTTPMethod.POST.value, + ) + else: + return StreamRequest( + messages=conn.receive(t), + spec=conn.spec, + peer=conn.peer, + headers=conn.request_headers, + method=HTTPMethod.POST.value, + ) -async def _receive_exactly_one[T](stream: AsyncIterator[T], aclose: Callable[[], Awaitable[None]]) -> T: - """Asynchronously receives exactly one item from an asynchronous iterator. - This function ensures that the provided asynchronous iterator (`stream`) yields - exactly one item. If the iterator yields no items or more than one item, a - `ConnectError` is raised. The provided `aclose` callable is always invoked to - close the stream, regardless of success or failure. +async def recieve_unary_response[T]( + conn: StreamingClientConn, t: type[T], abort_event: asyncio.Event | None +) -> UnaryResponse[T]: + """Receives a unary response message from a streaming client connection. - Type Parameters: - T: The type of the items in the asynchronous iterator. + This asynchronous function waits for a unary message of the specified type from the given + streaming client connection. It also handles optional abortion via an asyncio event. + The response, along with any headers and trailers from the connection, is wrapped in a + UnaryResponse object and returned. Args: - stream (AsyncIterator[T]): The asynchronous iterator to consume. - aclose (Callable[[], Awaitable[None]]): A callable that closes the stream - when invoked. + conn (StreamingClientConn): The streaming client connection to receive the message from. + t (type[T]): The expected type of the message to be received. + abort_event (asyncio.Event | None): Optional event to signal abortion of the receive operation. Returns: - T: The single item yielded by the asynchronous iterator. + UnaryResponse[T]: The received message and associated response metadata. Raises: - ConnectError: If the iterator yields no items or more than one item. + Any exceptions raised by `receive_unary_message` or connection errors. """ - try: - first = await stream.__anext__() - try: - await stream.__anext__() - raise ConnectError( - "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED - ) - except StopAsyncIteration: - return first - except StopAsyncIteration: - raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) from None - finally: - await aclose() + message = await ensure_single(conn.receive(t, abort_event), conn.aclose) + + return UnaryResponse(message, conn.response_headers, conn.response_trailers) async def recieve_stream_response[T]( @@ -783,34 +695,13 @@ async def recieve_stream_response[T]( - For other stream types, it directly returns the received stream. """ - receive_stream = AsyncDataStream[T](conn.receive(t, abort_event), conn.aclose) - if spec.stream_type == StreamType.ClientStream: - single_message = await _receive_exactly_one(receive_stream.__aiter__(), receive_stream.aclose) + single_message = await ensure_single(conn.receive(t, abort_event), conn.aclose) return StreamResponse( AsyncDataStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers ) else: - return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) - - -async def receive_unary_message[T](conn: UnaryClientConn, t: type[T]) -> T: - """Asynchronously receives a single unary message from the given connection. - - Args: - conn (UnaryClientConn): The unary client connection to receive the message from. - t (type[T]): The expected type of the message to be received. - - Returns: - T: The received message of type T. - - Raises: - Exception: If receiving the message fails or more than one message is received. - - Note: - This function ensures that exactly one message is received from the connection. - - """ - single_message = await _receive_exactly_one(conn.receive(t), conn.aclose) - return single_message + return StreamResponse( + AsyncDataStream[T](conn.receive(t, abort_event), conn.aclose), conn.response_headers, conn.response_trailers + ) diff --git a/src/connect/envelope.py b/src/connect/envelope.py index c29b490..01bf131 100644 --- a/src/connect/envelope.py +++ b/src/connect/envelope.py @@ -252,6 +252,7 @@ class EnvelopeReader: compression: Compression | None stream: AsyncIterable[bytes] | None buffer: bytes + bytes_read: int last_data: bytes | None def __init__( @@ -275,6 +276,7 @@ def __init__( self.compression = compression self.stream = stream self.buffer = b"" + self.bytes_read = 0 self.last_data = None async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: @@ -299,6 +301,7 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: async for chunk in self.stream: self.buffer += chunk + self.bytes_read += len(chunk) while True: env, data_len = Envelope.decode(self.buffer) @@ -352,6 +355,8 @@ async def aclose(self) -> None: This method checks if the `self.stream` object has an asynchronous `aclose` method. If the method exists, it is invoked to close the stream. + The bytes_read counter is not reset when closing the stream. + Returns: None diff --git a/src/connect/handler.py b/src/connect/handler.py index 1d462e4..cdb51ce 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -41,7 +41,7 @@ from connect.protocol_connect import ( ProtocolConnect, ) -from connect.protocol_grpc import ProtocolGPRC +from connect.protocol_grpc import ProtocolGRPC from connect.request import Request from connect.response import Response from connect.utils import aiterate @@ -123,7 +123,7 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: list[ProtocolHandler]: A list of initialized protocol handlers. """ - protocols = [ProtocolConnect(), ProtocolGPRC(web=False)] + protocols = [ProtocolConnect(), ProtocolGRPC(web=False)] codecs = CodecMap(config.codecs) @@ -336,6 +336,13 @@ async def stream_handle( except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + + if isinstance(e, TimeoutError): + error = ConnectError("the operation timed out", Code.DEADLINE_EXCEEDED) + + if isinstance(e, NotImplementedError): + error = ConnectError("not implemented", Code.UNIMPLEMENTED) + await conn.send_error(error) async def unary_handle( @@ -378,6 +385,7 @@ async def unary_handle( if isinstance(e, NotImplementedError): error = ConnectError("not implemented", Code.UNIMPLEMENTED) + await conn.send_error(error) diff --git a/src/connect/options.py b/src/connect/options.py index 2954648..ef7eba6 100644 --- a/src/connect/options.py +++ b/src/connect/options.py @@ -1,6 +1,6 @@ """Options for the UniversalHandler class.""" -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field @@ -84,6 +84,9 @@ class ClientOptions(BaseModel): enable_get: bool = Field(default=False) """A boolean indicating whether to enable GET requests.""" + protocol: Literal["connect", "grpc", "grpc-web"] = Field(default="connect") + """The protocol to use for the request.""" + def merge(self, override_options: "ClientOptions | None" = None) -> "ClientOptions": """Merge this options object with an override options object. diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 94e3e4f..2bce004 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -15,7 +15,6 @@ StreamingClientConn, StreamingHandlerConn, StreamType, - UnaryClientConn, ) from connect.error import ConnectError from connect.headers import Headers @@ -97,15 +96,10 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No raise NotImplementedError() @abc.abstractmethod - def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: + def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: """Return the connection for the client.""" raise NotImplementedError() - @abc.abstractmethod - def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Return the streaming connection for the client.""" - raise NotImplementedError() - class ProtocolHandler(abc.ABC): """Abstract base class for handling different protocols.""" diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 580a437..dc963fa 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -20,18 +20,17 @@ from google.protobuf import json_format from yarl import URL +from connect.byte_stream import HTTPCoreResponseAsyncByteStream from connect.code import Code from connect.codec import Codec, CodecNameType, StableCodec from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name from connect.connect import ( Address, - AsyncContentStream, Peer, Spec, StreamingClientConn, StreamingHandlerConn, StreamType, - UnaryClientConn, ensure_single, ) from connect.envelope import EnvelopeFlags, EnvelopeReader, EnvelopeWriter @@ -58,7 +57,6 @@ from connect.session import AsyncClientSession from connect.streaming_response import StreamingResponse from connect.utils import ( - AsyncByteStream, aiterate, get_acallable_attribute, map_httpcore_exceptions, @@ -683,7 +681,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: obj = await self.unmarshaler.unmarshal(message) yield obj - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message, unmarshals it, and returns the resulting object. Args: @@ -693,10 +691,7 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ - return AsyncContentStream( - self._receive_messages(message), - stream_type=self.spec.stream_type, - ) + return self._receive_messages(message) @property def request_headers(self) -> Headers: @@ -865,7 +860,7 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No if self.params.compressions: headers[accept_compression_header] = ", ".join(c.name for c in self.params.compressions) - def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: + def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: """Establish a unary client connection with the given specifications and headers. Args: @@ -876,65 +871,54 @@ def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: UnaryClientConn: The established unary client connection. """ - conn = ConnectUnaryClientConn( - session=self.params.session, - spec=spec, - peer=self.peer, - url=self.params.url, - compressions=self.params.compressions, - request_headers=headers, - marshaler=ConnectUnaryRequestMarshaler( - connect_marshaler=ConnectUnaryMarshaler( + conn: StreamingClientConn + if spec.stream_type == StreamType.Unary: + conn = ConnectUnaryClientConn( + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + compressions=self.params.compressions, + request_headers=headers, + marshaler=ConnectUnaryRequestMarshaler( + connect_marshaler=ConnectUnaryMarshaler( + codec=self.params.codec, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + headers=headers, + ) + ), + unmarshaler=ConnectUnaryUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + ) + if spec.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: + conn.marshaler.enable_get = self.params.enable_get + conn.marshaler.url = self.params.url + if isinstance(self.params.codec, StableCodec): + conn.marshaler.stable_codec = self.params.codec + else: + conn = ConnectStreamingClientConn( + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + codec=self.params.codec, + compressions=self.params.compressions, + request_headers=headers, + marshaler=ConnectStreamingMarshaler( codec=self.params.codec, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), compress_min_bytes=self.params.compress_min_bytes, send_max_bytes=self.params.send_max_bytes, - headers=headers, - ) - ), - unmarshaler=ConnectUnaryUnmarshaler( - codec=self.params.codec, - read_max_bytes=self.params.read_max_bytes, - ), - ) - if spec.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: - conn.marshaler.enable_get = self.params.enable_get - conn.marshaler.url = self.params.url - if isinstance(self.params.codec, StableCodec): - conn.marshaler.stable_codec = self.params.codec - - return conn - - def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Establish a streaming connection using the provided specification and headers. - - Args: - spec (Spec): The specification for the streaming connection. - headers (Headers): The headers to be included in the connection request. - - Returns: - StreamingClientConn: An instance of the streaming client connection. - - """ - conn = ConnectStreamingClientConn( - session=self.params.session, - spec=spec, - peer=self.peer, - url=self.params.url, - codec=self.params.codec, - compressions=self.params.compressions, - request_headers=headers, - marshaler=ConnectStreamingMarshaler( - codec=self.params.codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), - ), - unmarshaler=ConnectStreamingUnmarshaler( - codec=self.params.codec, - read_max_bytes=self.params.read_max_bytes, - ), - ) + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + ), + unmarshaler=ConnectStreamingUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + ) return conn @@ -1106,50 +1090,6 @@ def _write_with_get(self, url: URL) -> None: self.url = url -class HTTPCoreResponseAsyncByteStream(AsyncByteStream): - """An asynchronous byte stream for reading and writing byte chunks.""" - - aiterator: AsyncIterable[bytes] | None - _closed: bool - - def __init__( - self, - aiterator: AsyncIterable[bytes] | None = None, - ) -> None: - """Initialize the protocol connect instance. - - Args: - aiterator (AsyncIterable[bytes] | None): An optional asynchronous iterable of bytes. - - Returns: - None - - """ - self.aiterator = aiterator - self._closed = False - - async def __aiter__(self) -> AsyncIterator[bytes]: - """Asynchronous iterator method to read byte chunks from the stream.""" - if self.aiterator: - try: - with map_httpcore_exceptions(): - async for chunk in self.aiterator: - yield chunk - except BaseException as exc: - await self.aclose() - raise exc - - async def aclose(self) -> None: - """Asynchronously close the stream.""" - if not self._closed and self.aiterator: - aclose = get_acallable_attribute(self.aiterator, "aclose") - if not aclose: - return - - with map_httpcore_exceptions(): - await aclose() - - class ConnectStreamingMarshaler(EnvelopeWriter): """A class responsible for marshaling messages with optional compression. @@ -1393,7 +1333,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: async for obj, _ in self.unmarshaler.unmarshal(message): yield obj - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and returns an asynchronous content stream. This method processes the incoming message through the receive_message method @@ -1407,10 +1347,7 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: processed message, configured with the specification's stream type. """ - return AsyncContentStream( - iterable=self._receive_messages(message), - stream_type=self.spec.stream_type, - ) + return self._receive_messages(message) @property def request_headers(self) -> Headers: @@ -1818,7 +1755,7 @@ async def aclose(self) -> None: await self.unmarshaler.aclose() -class ConnectUnaryClientConn(UnaryClientConn): +class ConnectUnaryClientConn(StreamingClientConn): """A client connection for unary RPCs using the Connect protocol. Attributes: @@ -1927,7 +1864,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: obj = await self.unmarshaler.unmarshal(message) yield obj - def receive(self, message: Any) -> AsyncIterator[Any]: + def receive(self, message: Any, _abort_event: asyncio.Event | None) -> AsyncIterator[Any]: """Receives a message and returns an asynchronous iterator over the processed message. Args: @@ -1960,22 +1897,28 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: - """Send a message asynchronously using the specified HTTP method and handles the response. + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + """Send a single message asynchronously using either HTTP GET or POST, with support for timeouts and request abortion. Args: - message (Any): The message to be sent, which will be marshaled before sending. - timeout (float | None): The timeout for the request in seconds. If provided, it will be - included in the request headers and extensions. - abort_event (asyncio.Event | None): An optional asyncio event that can be used to abort - the request. If the event is set, the request will be canceled. - - Returns: - bytes: The marshaled data of the message that was sent. + messages (AsyncIterable[Any]): An asynchronous iterable yielding the message(s) to send. Only a single message is allowed. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets a read timeout for the request. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, aborts the request. Raises: - ConnectError: If the request is aborted or if there are issues during the request/response - lifecycle. + ConnectError: If the request is aborted before or during execution, or if other connection errors occur. + + Side Effects: + - Modifies request headers for timeout and content length as needed. + - Invokes registered request and response event hooks. + - Sets the unmarshaler's stream to the response stream for further processing. + - Validates the response after receiving it. + + Notes: + - If `marshaler.enable_get` is True, sends the request as HTTP GET; otherwise, uses HTTP POST. + - Handles cancellation and cleanup if the abort event is triggered during the request. """ if abort_event and abort_event.is_set(): @@ -1986,6 +1929,7 @@ async def send(self, message: Any, timeout: float | None, abort_event: asyncio.E extensions["timeout"] = {"read": timeout} self._request_headers[CONNECT_HEADER_TIMEOUT] = str(int(timeout * 1000)) + message = await ensure_single(messages) data = self.marshaler.marshal(message) if self.marshaler.enable_get: @@ -2059,8 +2003,6 @@ async def send(self, message: Any, timeout: float | None, abort_event: asyncio.E await self._validate_response(response) - return data - @property def response_headers(self) -> Headers: """Return the response headers. diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 90b4b97..a7e02e7 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -1,35 +1,56 @@ """Provaides classes and functions for handling gRPC protocol.""" +import asyncio import base64 +import contextlib +import functools import re +import sys import urllib.parse -from collections.abc import AsyncIterable, AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping from http import HTTPMethod from typing import Any +from urllib.parse import unquote +import httpcore +from google.protobuf.message import DecodeError from google.rpc import status_pb2 +from yarl import URL +from connect.byte_stream import HTTPCoreResponseAsyncByteStream from connect.code import Code from connect.codec import Codec, CodecNameType -from connect.compression import COMPRESSION_IDENTITY, Compression -from connect.connect import Address, AsyncContentStream, Peer, Spec, StreamingHandlerConn +from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name +from connect.connect import ( + Address, + Peer, + Spec, + StreamingClientConn, + StreamingHandlerConn, + StreamType, +) from connect.envelope import EnvelopeReader, EnvelopeWriter -from connect.error import ConnectError -from connect.headers import Headers +from connect.error import ConnectError, ErrorDetail +from connect.headers import Headers, include_request_headers from connect.protocol import ( HEADER_CONTENT_TYPE, + HEADER_USER_AGENT, PROTOCOL_GRPC, + PROTOCOL_GRPC_WEB, Protocol, ProtocolClient, ProtocolClientParams, ProtocolHandler, ProtocolHandlerParams, + code_from_http_status, exclude_protocol_headers, negotiate_compression, ) from connect.request import Request +from connect.session import AsyncClientSession from connect.streaming_response import StreamingResponse -from connect.utils import aiterate +from connect.utils import map_httpcore_exceptions +from connect.version import __version__ from connect.writer import ServerResponseWriter GRPC_HEADER_COMPRESSION = "Grpc-Encoding" @@ -44,23 +65,27 @@ GRPC_CONTENT_TYPE_PREFIX = GRPC_CONTENT_TYPE_DEFAULT + "+" GRPC_WEB_CONTENT_TYPE_PREFIX = GRPC_WEB_CONTENT_TYPE_DEFAULT + "+" +HEADER_X_USER_AGENT = "X-User-Agent" + GRPC_ALLOWED_METHODS = [HTTPMethod.POST] +DEFAULT_GRPC_USER_AGENT = f"connect-python/{__version__} (Python/{__version__})" + _RE = re.compile(r"^(\d{1,8})([HMSmun])$") _UNIT_TO_SECONDS = { - "H": 60 * 60, - "M": 60, - "S": 1, - "m": 1e-3, # millisecond - "u": 1e-6, # microsecond "n": 1e-9, # nanosecond + "u": 1e-6, # microsecond + "m": 1e-3, # millisecond + "S": 1.0, + "M": 60.0, + "H": 3600.0, } -_MAX_HOURS = (2**63 - 1) // (60 * 60 * 1_000_000_000) +_MAX_HOURS = sys.maxsize // (60 * 60 * 1_000_000_000) -class ProtocolGPRC(Protocol): - """ProtocolGPRC is a protocol implementation for handling gRPC and gRPC-Web requests. +class ProtocolGRPC(Protocol): + """ProtocolGRPC is a protocol implementation for handling gRPC and gRPC-Web requests. Attributes: web (bool): Indicates whether to use gRPC-Web (True) or standard gRPC (False). @@ -115,7 +140,15 @@ def client(self, params: ProtocolClientParams) -> ProtocolClient: ProtocolClient: An instance of GRPCClient. """ - raise NotImplementedError("GRPC client is not implemented yet.") + peer = Peer( + address=Address(host=params.url.host or "", port=params.url.port or 80), + protocol=PROTOCOL_GRPC, + query={}, + ) + if self.web: + peer.protocol = PROTOCOL_GRPC_WEB + + return GRPCClient(params, peer, self.web) class GRPCHandler(ProtocolHandler): @@ -233,7 +266,6 @@ async def conn( spec=self.params.spec, peer=peer, marshaler=GRPCMarshaler( - self.web, codec, response_compression, self.params.compress_min_bytes, @@ -257,11 +289,123 @@ async def conn( return conn +class GRPCClient(ProtocolClient): + """GRPCClient is a protocol client implementation for gRPC communication, supporting both standard and web environments. + + Attributes: + params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, session, and URL. + _peer (Peer): The peer instance associated with this client, representing the remote endpoint. + web (bool): Indicates whether the client is running in a web environment, affecting header and content-type handling. + + """ + + params: ProtocolClientParams + _peer: Peer + web: bool + + def __init__(self, params: ProtocolClientParams, peer: Peer, web: bool) -> None: + """Initialize the ProtocolClient with the given parameters. + + Args: + params (ProtocolClientParams): The parameters for the protocol client. + peer (Peer): The peer instance to be used. + web (bool): Indicates whether the client is running in a web environment. + + """ + self.params = params + self._peer = peer + self.web = web + + @property + def peer(self) -> Peer: + """Returns the associated Peer object. + + Returns: + Peer: The peer instance associated with this object. + + """ + return self._peer + + def write_request_headers(self, _: StreamType, headers: Headers) -> None: + """Set and modifies HTTP/2 or gRPC request headers based on the stream type, connection parameters, and environment. + + Args: + stream_type (StreamType): The type of stream for which headers are being written. + headers (Headers): The dictionary of headers to be modified or populated. + + Behavior: + - Ensures the 'User-Agent' header is set to the default gRPC user agent if not already present. + - If running in a web environment, also sets the 'X-User-Agent' header. + - Sets the 'Content-Type' header according to the codec name and environment. + - Sets the 'Accept-Encoding' header to indicate supported compression. + - If a specific compression is configured and is not the identity, sets the gRPC compression header. + - If multiple compressions are supported, sets the gRPC accept compression header with the supported values. + - For non-web environments, adds the 'Te: trailers' header required for gRPC. + + Note: + This method mutates the provided headers dictionary in place. + + """ + if headers.get(HEADER_USER_AGENT, None) is None: + headers[HEADER_USER_AGENT] = DEFAULT_GRPC_USER_AGENT + + if self.web and headers.get(HEADER_X_USER_AGENT, None) is None: + headers[HEADER_X_USER_AGENT] = DEFAULT_GRPC_USER_AGENT + + headers[HEADER_CONTENT_TYPE] = grpc_content_type_from_codec_name(self.web, self.params.codec.name) + + headers["Accept-Encoding"] = COMPRESSION_IDENTITY + if self.params.compression_name and self.params.compression_name != COMPRESSION_IDENTITY: + headers[GRPC_HEADER_COMPRESSION] = self.params.compression_name + + if self.params.compressions: + headers[GRPC_HEADER_ACCEPT_COMPRESSION] = ", ".join(c.name for c in self.params.compressions) + + if not self.web: + headers["Te"] = "trailers" + + def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: + """Create and returns a GRPCClientConn instance configured with the provided specification and headers. + + Args: + spec (Spec): The specification object defining the protocol or service interface. + headers (Headers): The request headers to include in the connection. + + Returns: + StreamingClientConn: An initialized gRPC streaming client connection. + + Details: + - Configures the connection with parameters such as session, peer, URL, codec, and compression settings. + - Initializes GRPCMarshaler and GRPCUnmarshaler with appropriate codecs and limits. + - Compression is determined using the provided compression name and available compressions. + + """ + return GRPCClientConn( + web=self.web, + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + codec=self.params.codec, + compressions=self.params.compressions, + marshaler=GRPCMarshaler( + codec=self.params.codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + ), + unmarshaler=GRPCUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=headers, + ) + + class GRPCMarshaler(EnvelopeWriter): """GRPCMarshaler is responsible for marshaling messages into the gRPC wire format. Args: - web (bool): Indicates whether to use the gRPC-web protocol. codec (Codec | None): The codec used for encoding/decoding messages. compression (Compression | None): The compression algorithm to use, if any. compress_min_bytes (int): Minimum message size in bytes before compression is applied. @@ -274,11 +418,8 @@ class GRPCMarshaler(EnvelopeWriter): """ - web: bool - def __init__( self, - web: bool, codec: Codec | None, compression: Compression | None, compress_min_bytes: int, @@ -287,7 +428,6 @@ def __init__( """Initialize the protocol with the specified configuration. Args: - web (bool): Indicates whether the protocol is used in a web context. codec (Codec | None): The codec to use for encoding/decoding messages, or None for default. compression (Compression | None): The compression algorithm to use, or None for no compression. compress_min_bytes (int): The minimum number of bytes before compression is applied. @@ -298,7 +438,6 @@ def __init__( """ super().__init__(codec, compression, compress_min_bytes, send_max_bytes) - self.web = web class GRPCUnmarshaler(EnvelopeReader): @@ -349,6 +488,273 @@ async def unmarshal(self, message: Any) -> AsyncIterator[Any]: yield obj +EventHook = Callable[..., Any] + + +class GRPCClientConn(StreamingClientConn): + """GRPCClientConn is a gRPC client connection implementation supporting asynchronous streaming requests and responses over HTTP/2. + + This class manages the lifecycle of a gRPC client connection, including marshaling and unmarshaling messages, handling request and response headers/trailers, managing compression, and supporting event hooks for request/response events. It integrates with an asynchronous HTTP client session and supports cancellation via asyncio events. + + Attributes: + session (AsyncClientSession): The asynchronous client session used for HTTP requests. + _spec (Spec): The protocol or API specification. + _peer (Peer): Information about the remote peer. + url (URL): The endpoint URL for the connection. + codec (Codec | None): Codec for encoding/decoding messages. + compressions (list[Compression]): Supported compression algorithms. + marshaler (GRPCMarshaler): Marshaler for serializing messages. + unmarshaler (GRPCUnmarshaler): Unmarshaler for deserializing messages. + _response_headers (Headers): HTTP response headers. + _response_trailers (Headers): HTTP response trailers. + _request_headers (Headers): HTTP request headers. + receive_trailers (Callable[[], None] | None): Callback to receive trailers after response. + + """ + + web: bool + session: AsyncClientSession + _spec: Spec + _peer: Peer + url: URL + codec: Codec | None + compressions: list[Compression] + marshaler: GRPCMarshaler + unmarshaler: GRPCUnmarshaler + _response_headers: Headers + _response_trailers: Headers + _request_headers: Headers + receive_trailers: Callable[[], None] | None + + def __init__( + self, + web: bool, + session: AsyncClientSession, + spec: Spec, + peer: Peer, + url: URL, + codec: Codec | None, + compressions: list[Compression], + request_headers: Headers, + marshaler: GRPCMarshaler, + unmarshaler: GRPCUnmarshaler, + event_hooks: None | (Mapping[str, list[EventHook]]) = None, + ) -> None: + """Initialize a new instance of the class. + + Args: + web (bool): Indicates if the connection is for a web environment. + session (AsyncClientSession): The asynchronous client session to use for requests. + spec (Spec): The specification object describing the protocol or API. + peer (Peer): The peer information for the connection. + url (URL): The URL endpoint for the connection. + codec (Codec | None): The codec to use for encoding/decoding messages, or None. + compressions (list[Compression]): List of supported compression algorithms. + request_headers (Headers): Headers to include in outgoing requests. + marshaler (GRPCMarshaler): The marshaler for serializing messages. + unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing messages. + event_hooks (None | Mapping[str, list[EventHook]], optional): Optional mapping of event hooks for "request" and "response" events. Defaults to None. + + """ + event_hooks = {} if event_hooks is None else event_hooks + + self.web = web + self.session = session + self._spec = spec + self._peer = peer + self.url = url + self.codec = codec + self.compressions = compressions + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self._response_headers = Headers() + self._response_trailers = Headers() + self._request_headers = request_headers + + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def spec(self) -> Spec: + """Return the specification details.""" + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer information.""" + raise NotImplementedError() + + async def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: + """Receives a message and processes it.""" + async for obj in self.unmarshaler.unmarshal(message): + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + + yield obj + + if callable(self.receive_trailers): + self.receive_trailers() + + if self.unmarshaler.bytes_read == 0 and len(self.response_trailers) == 0: + self.response_trailers.update(self._response_headers) + del self._response_headers[HEADER_CONTENT_TYPE] + + server_error = grpc_error_from_trailer(self.response_trailers) + if server_error: + server_error.metadata = self.response_headers.copy() + raise server_error + + server_error = grpc_error_from_trailer(self.response_trailers) + if server_error: + server_error.metadata = self.response_headers.copy() + server_error.metadata.update(self.response_trailers) + raise server_error + + def _receive_trailers(self, response: httpcore.Response) -> None: + if "trailing_headers" not in response.extensions: + return + + trailers = response.extensions["trailing_headers"] + self._response_trailers.update(Headers(trailers)) + + @property + def request_headers(self) -> Headers: + """Return the request headers.""" + return self._request_headers + + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + """Send a gRPC request asynchronously using HTTP/2 via httpcore, handling streaming messages, timeouts, and abort events. + + Args: + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent as the request body. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets the gRPC timeout header. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request and raise a cancellation error. + + Raises: + ConnectError: If the request is aborted before or during execution, or if an error occurs during the HTTP request. + + Side Effects: + - Invokes registered request and response event hooks. + - Sets up the response stream and trailers for further processing. + - Validates the HTTP response. + + """ + if abort_event and abort_event.is_set(): + raise ConnectError("request aborted", Code.CANCELED) + + extensions = {} + if timeout: + extensions["timeout"] = {"read": timeout} + self._request_headers[GRPC_HEADER_TIMEOUT] = grpc_encode_timeout(timeout) + + content_iterator = self.marshaler.marshal(messages) + + request = httpcore.Request( + method=HTTPMethod.POST, + url=httpcore.URL( + scheme=self.url.scheme, + host=self.url.host or "", + port=self.url.port, + target=self.url.raw_path, + ), + headers=list( + include_request_headers( + headers=self._request_headers, url=self.url, content=content_iterator, method=HTTPMethod.POST + ).items() + ), + content=content_iterator, + extensions=extensions, + ) + + for hook in self._event_hooks["request"]: + hook(request) + + with map_httpcore_exceptions(): + if not abort_event: + response = await self.session.pool.handle_async_request(request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task + + raise ConnectError("request aborted", Code.CANCELED) + + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task + + for hook in self._event_hooks["response"]: + hook(response) + + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) + self.receive_trailers = functools.partial(self._receive_trailers, response) + + await self._validate_response(response) + + async def _validate_response(self, response: httpcore.Response) -> None: + response_headers = Headers(response.headers) + if response.status != 200: + raise ConnectError( + f"HTTP {response.status}", + code_from_http_status(response.status), + ) + + grpc_validate_response_content_type( + self.web, + self.marshaler.codec.name if self.marshaler.codec else "", + response_headers.get(HEADER_CONTENT_TYPE, ""), + ) + + compression = response_headers.get(GRPC_HEADER_COMPRESSION, None) + if compression and compression != COMPRESSION_IDENTITY: + self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) + + self._response_headers.update(response_headers) + + @property + def response_headers(self) -> Headers: + """Return the response headers.""" + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Return response trailers.""" + return self._response_trailers + + def on_request_send(self, fn: EventHook) -> None: + """Register a callback function to be invoked when a request is sent. + + Args: + fn (EventHook): The callback function to be added to the "request" event hook. + + Returns: + None + + """ + self._event_hooks["request"].append(fn) + + async def aclose(self) -> None: + """Asynchronously closes the underlying unmarshaler resource. + + This method should be called to properly release any resources held by the unmarshaler, + such as open network connections or file handles, when they are no longer needed. + """ + await self.unmarshaler.aclose() + + class GRPCHandlerConn(StreamingHandlerConn): """GRPCHandlerConn is a handler class for managing gRPC protocol connections within a streaming server context. @@ -432,7 +838,6 @@ def parse_timeout(self) -> float | None: num_str, unit = m.groups() num = int(num_str) - if num > 99_999_999: raise ConnectError(f"protocol error: timeout {timeout!r} is too long") @@ -462,7 +867,7 @@ def peer(self) -> Peer: """ return self._peer - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and processes it. Args: @@ -473,7 +878,7 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: this will yield exactly one item. """ - return AsyncContentStream(self.unmarshaler.unmarshal(message), self.spec.stream_type) + return self.unmarshaler.unmarshal(message) @property def request_headers(self) -> Headers: @@ -564,7 +969,10 @@ async def send_error(self, error: ConnectError) -> None: await self.writer.write( StreamingResponse( - content=aiterate([b""]), headers=self.response_headers, trailers=self.response_trailers, status_code=200 + content=[], + headers=self.response_headers, + trailers=self.response_trailers, + status_code=200, ) ) @@ -625,12 +1033,203 @@ def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: ) code = status.code message = status.message - bin = None + details_binary = None if len(status.details) > 0: - bin = status.SerializeToString() + details_binary = status.SerializeToString() trailer[GRPC_HEADER_STATUS] = str(code) trailer[GRPC_HEADER_MESSAGE] = urllib.parse.quote(message) - if bin: - trailer[GRPC_HEADER_DETAILS] = base64.b64encode(bin).decode().rstrip("=") + if details_binary: + trailer[GRPC_HEADER_DETAILS] = base64.b64encode(details_binary).decode().rstrip("=") + + +def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: + """Return the appropriate gRPC content type string based on the given codec name and whether the request is for gRPC-Web. + + Args: + web (bool): Indicates if the content type is for gRPC-Web (True) or standard gRPC (False). + codec_name (str): The name of the codec (e.g., "proto", "json"). + + Returns: + str: The corresponding gRPC content type string. + + """ + if web: + return GRPC_WEB_CONTENT_TYPE_PREFIX + codec_name + + if codec_name == CodecNameType.PROTO: + return GRPC_CONTENT_TYPE_DEFAULT + + return GRPC_CONTENT_TYPE_PREFIX + codec_name + + +def grpc_validate_response_content_type(web: bool, request_codec_name: str, response_content_type: str) -> None: + """Validate that the gRPC response content type matches the expected value based on the request codec and whether gRPC-Web is used. + + Args: + web (bool): Indicates if gRPC-Web is being used. + request_codec_name (str): The name of the codec used in the request (e.g., "proto", "json"). + response_content_type (str): The content type returned in the response. + + Raises: + ConnectError: If the response content type does not match the expected value, with an appropriate error code. + + """ + bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX + if web: + bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX + + if response_content_type == prefix + request_codec_name or ( + request_codec_name == CodecNameType.PROTO and response_content_type == bare + ): + return + + expected_content_type = bare + if request_codec_name != CodecNameType.PROTO: + expected_content_type = prefix + request_codec_name + + code = Code.INTERNAL + if response_content_type != bare and not response_content_type.startswith(prefix): + code = Code.UNKNOWN + + raise ConnectError(f"invalid content-type {response_content_type}, expected {expected_content_type}", code) + + +def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: + """Parse gRPC error information from response trailers and constructs a ConnectError if present. + + Args: + trailers (Headers): The gRPC response trailers containing error information. + + Returns: + ConnectError | None: Returns a ConnectError instance if an error is found in the trailers, + or None if the status code indicates success. + + Raises: + ConnectError: If the grpc-status-details-bin trailer or protobuf error details are invalid. + + The function extracts the gRPC status code, error message, and optional error details from the trailers. + If the status code is missing or invalid, it returns a ConnectError with an appropriate message. + If the status code indicates success ("0"), it returns None. + If error details are present and valid, they are attached to the ConnectError. + + """ + code_header = trailers.get(GRPC_HEADER_STATUS) + if code_header is None: + code = Code.UNKNOWN + if len(trailers) == 0: + code = Code.INTERNAL + + return ConnectError( + f"protocol error: no {GRPC_HEADER_STATUS} header in trailers", + code, + ) + + if code_header == "0": + return None + + try: + code = Code(int(code_header)) + except ValueError: + return ConnectError( + f"protocol error: invalid error code {code_header} in trailers", + ) + + try: + message = unquote(trailers.get(GRPC_HEADER_MESSAGE, "")) + except Exception: + return ConnectError( + f"protocol error: invalid error message {code_header} in trailers", + code=Code.UNKNOWN, + ) + + ret_error = ConnectError( + message, + code, + wire_error=True, + ) + + details_binary_encoded = trailers.get(GRPC_HEADER_DETAILS, None) + if details_binary_encoded and len(details_binary_encoded) > 0: + try: + details_binary = decode_binary_header(details_binary_encoded) + except Exception as e: + raise ConnectError( + f"server returned invalid grpc-status-details-bin trailer: {e}", + code=Code.INTERNAL, + ) from e + + status = status_pb2.Status() + try: + status.ParseFromString(details_binary) + except DecodeError as e: + raise ConnectError( + f"server returned invalid protobuf for error details: {e}", + code=Code.INTERNAL, + ) from e + + for detail in status.details: + ret_error.details.append(ErrorDetail(pb_any=detail)) + + ret_error.code = Code(status.code) + ret_error.raw_message = status.message + + return ret_error + + +def decode_binary_header(data: str) -> bytes: + """Decode a base64-encoded string representing a binary header. + + If the input string's length is not a multiple of 4, it pads the string with '=' characters + to make it valid base64 before decoding. + + Args: + data (str): The base64-encoded string to decode. + + Returns: + bytes: The decoded binary data. + + Raises: + binascii.Error: If the input is not correctly base64-encoded. + + """ + if len(data) % 4: + data += "=" * (-len(data) % 4) + + return base64.b64decode(data, validate=True) + + +def grpc_encode_timeout(timeout: float) -> str: + """Encode a timeout value (in seconds) into the gRPC timeout format string. + + The gRPC timeout format is a decimal number with a time unit suffix, where the unit can be: + - 'H' for hours + - 'M' for minutes + - 'S' for seconds + - 'm' for milliseconds + - 'u' for microseconds + - 'n' for nanoseconds + + If the timeout is less than or equal to zero, returns "0n". + + Args: + timeout (float): The timeout value in seconds. + + Returns: + str: The timeout encoded as a gRPC timeout string. + + """ + if timeout <= 0: + return "0n" + + grpc_timeout_max_value = 10**8 + + _units = dict(sorted(_UNIT_TO_SECONDS.items(), key=lambda item: item[1])) + for unit, size in _units.items(): + if timeout < size * grpc_timeout_max_value: + value = int(timeout / size) + return f"{value}{unit}" + + value = int(timeout / 3600.0) + return f"{value}H" diff --git a/uv.lock b/uv.lock index b388c30..68c5cd5 100644 --- a/uv.lock +++ b/uv.lock @@ -130,7 +130,7 @@ requires-dist = [ { name = "anyio", specifier = ">=4.7.0" }, { name = "googleapis-common-protos", specifier = ">=1.70.0" }, { name = "h2", specifier = ">=4.2.0" }, - { name = "httpcore", specifier = ">=1.0.7" }, + { name = "httpcore", git = "https://github.com/tsubakiky/httpcore" }, { name = "protobuf", specifier = ">=5.29.1" }, { name = "pydantic", specifier = ">=2.10.4" }, { name = "starlette", specifier = ">=0.46.0" }, @@ -205,11 +205,11 @@ wheels = [ [[package]] name = "h11" -version = "0.14.0" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418, upload_time = "2022-09-25T15:40:01.519Z" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload_time = "2025-04-24T03:35:25.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload_time = "2022-09-25T15:39:59.68Z" }, + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload_time = "2025-04-24T03:35:24.344Z" }, ] [[package]] @@ -236,16 +236,12 @@ wheels = [ [[package]] name = "httpcore" -version = "1.0.7" -source = { registry = "https://pypi.org/simple" } +version = "1.0.9" +source = { git = "https://github.com/tsubakiky/httpcore#e70c821d72d7b9c5634c781cd454ced911052c29" } dependencies = [ { name = "certifi" }, { name = "h11" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196, upload_time = "2024-11-15T12:30:47.531Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551, upload_time = "2024-11-15T12:30:45.782Z" }, -] [[package]] name = "httpx"