From a6e04e94968c2592eb5a5b46d8b6a07716e7f3a5 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 2 May 2025 03:25:22 +0900 Subject: [PATCH 01/14] protocol_grpc: add grpc client --- conformance/client_config.yaml | 7 +- conformance/client_runner.py | 5 + conformance/run-testcase.txt | 2 +- src/connect/byte_stream.py | 54 ++++++ src/connect/client.py | 9 +- src/connect/handler.py | 4 +- src/connect/options.py | 6 + src/connect/protocol_connect.py | 46 +----- src/connect/protocol_grpc.py | 282 ++++++++++++++++++++++++++++++-- 9 files changed, 347 insertions(+), 68 deletions(-) create mode 100644 src/connect/byte_stream.py diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index 22fb4f4..5817efa 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -3,17 +3,18 @@ features: - HTTP_VERSION_1 - HTTP_VERSION_2 protocols: - - PROTOCOL_CONNECT + # - PROTOCOL_CONNECT + - PROTOCOL_GRPC codecs: - CODEC_PROTO compressions: - COMPRESSION_IDENTITY - COMPRESSION_GZIP stream_types: - - STREAM_TYPE_UNARY + # - STREAM_TYPE_UNARY - STREAM_TYPE_CLIENT_STREAM - STREAM_TYPE_SERVER_STREAM - - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM + # - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM # - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM supports_h2c: true diff --git a/conformance/client_runner.py b/conformance/client_runner.py index b5203a8..adb18cb 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -218,6 +218,11 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c payloads = [] try: options = ClientOptions() + if msg.protocol == config_pb2.PROTOCOL_GRPC: + options.grpc = True + if msg.protocol == config_pb2.PROTOCOL_GRPC_WEB: + options.grpc = True + if msg.compression == config_pb2.COMPRESSION_GZIP: options.request_compression_name = "gzip" diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index 26b35a5..92362ec 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -gRPC Unexpected Requests/HTTPVersion:2/TLS:true/unary/multiple-requests +Basic/HTTPVersion:2/Protocol:PROTOCOL_GRPC/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:false/client-stream/success diff --git a/src/connect/byte_stream.py b/src/connect/byte_stream.py new file mode 100644 index 0000000..45bd122 --- /dev/null +++ b/src/connect/byte_stream.py @@ -0,0 +1,54 @@ +from collections.abc import ( + AsyncIterable, + AsyncIterator, +) + +from connect.utils import ( + AsyncByteStream, + get_acallable_attribute, + map_httpcore_exceptions, +) + + +class HTTPCoreResponseAsyncByteStream(AsyncByteStream): + """An asynchronous byte stream for reading and writing byte chunks.""" + + aiterator: AsyncIterable[bytes] | None + _closed: bool + + def __init__( + self, + aiterator: AsyncIterable[bytes] | None = None, + ) -> None: + """Initialize the protocol connect instance. + + Args: + aiterator (AsyncIterable[bytes] | None): An optional asynchronous iterable of bytes. + + Returns: + None + + """ + self.aiterator = aiterator + self._closed = False + + async def __aiter__(self) -> AsyncIterator[bytes]: + """Asynchronous iterator method to read byte chunks from the stream.""" + if self.aiterator: + try: + with map_httpcore_exceptions(): + async for chunk in self.aiterator: + yield chunk + except BaseException as exc: + await self.aclose() + raise exc + + async def aclose(self) -> None: + """Asynchronously close the stream.""" + if not self._closed and self.aiterator: + aclose = get_acallable_attribute(self.aiterator, "aclose") + if not aclose: + return + + with map_httpcore_exceptions(): + await aclose() diff --git a/src/connect/client.py b/src/connect/client.py index b3393d0..451f200 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -27,8 +27,9 @@ from connect.idempotency_level import IdempotencyLevel from connect.interceptor import apply_interceptors from connect.options import ClientOptions -from connect.protocol import ProtocolClient, ProtocolClientParams +from connect.protocol import Protocol, ProtocolClient, ProtocolClientParams from connect.protocol_connect import ProtocolConnect +from connect.protocol_grpc import ProtocolGRPC from connect.session import AsyncClientSession @@ -75,7 +76,7 @@ class ClientConfig: """ url: URL - protocol: ProtocolConnect + protocol: Protocol procedure: str codec: Codec request_compression_name: str | None @@ -113,6 +114,10 @@ def __init__(self, raw_url: str, options: ClientOptions): self.url = url self.protocol = ProtocolConnect() + if options.grpc: + self.protocol = ProtocolGRPC(web=False) + if options.grpc_web: + self.protocol = ProtocolGRPC(web=True) self.procedure = proto_path self.codec = ProtoBinaryCodec() self.request_compression_name = options.request_compression_name diff --git a/src/connect/handler.py b/src/connect/handler.py index 1d462e4..a68849a 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -41,7 +41,7 @@ from connect.protocol_connect import ( ProtocolConnect, ) -from connect.protocol_grpc import ProtocolGPRC +from connect.protocol_grpc import ProtocolGRPC from connect.request import Request from connect.response import Response from connect.utils import aiterate @@ -123,7 +123,7 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: list[ProtocolHandler]: A list of initialized protocol handlers. """ - protocols = [ProtocolConnect(), ProtocolGPRC(web=False)] + protocols = [ProtocolConnect(), ProtocolGRPC(web=False)] codecs = CodecMap(config.codecs) diff --git a/src/connect/options.py b/src/connect/options.py index 2954648..94f1190 100644 --- a/src/connect/options.py +++ b/src/connect/options.py @@ -84,6 +84,12 @@ class ClientOptions(BaseModel): enable_get: bool = Field(default=False) """A boolean indicating whether to enable GET requests.""" + grpc: bool = Field(default=False) + """A boolean indicating whether to use gRPC.""" + + grpc_web: bool = Field(default=False) + """A boolean indicating whether to use gRPC-Web.""" + def merge(self, override_options: "ClientOptions | None" = None) -> "ClientOptions": """Merge this options object with an override options object. diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 580a437..8ae087c 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -20,6 +20,7 @@ from google.protobuf import json_format from yarl import URL +from connect.byte_stream import HTTPCoreResponseAsyncByteStream from connect.code import Code from connect.codec import Codec, CodecNameType, StableCodec from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name @@ -58,7 +59,6 @@ from connect.session import AsyncClientSession from connect.streaming_response import StreamingResponse from connect.utils import ( - AsyncByteStream, aiterate, get_acallable_attribute, map_httpcore_exceptions, @@ -1106,50 +1106,6 @@ def _write_with_get(self, url: URL) -> None: self.url = url -class HTTPCoreResponseAsyncByteStream(AsyncByteStream): - """An asynchronous byte stream for reading and writing byte chunks.""" - - aiterator: AsyncIterable[bytes] | None - _closed: bool - - def __init__( - self, - aiterator: AsyncIterable[bytes] | None = None, - ) -> None: - """Initialize the protocol connect instance. - - Args: - aiterator (AsyncIterable[bytes] | None): An optional asynchronous iterable of bytes. - - Returns: - None - - """ - self.aiterator = aiterator - self._closed = False - - async def __aiter__(self) -> AsyncIterator[bytes]: - """Asynchronous iterator method to read byte chunks from the stream.""" - if self.aiterator: - try: - with map_httpcore_exceptions(): - async for chunk in self.aiterator: - yield chunk - except BaseException as exc: - await self.aclose() - raise exc - - async def aclose(self) -> None: - """Asynchronously close the stream.""" - if not self._closed and self.aiterator: - aclose = get_acallable_attribute(self.aiterator, "aclose") - if not aclose: - return - - with map_httpcore_exceptions(): - await aclose() - - class ConnectStreamingMarshaler(EnvelopeWriter): """A class responsible for marshaling messages with optional compression. diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 90b4b97..688d4a0 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -1,24 +1,40 @@ """Provaides classes and functions for handling gRPC protocol.""" +import asyncio import base64 +import contextlib import re import urllib.parse -from collections.abc import AsyncIterable, AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping from http import HTTPMethod from typing import Any +import httpcore from google.rpc import status_pb2 +from yarl import URL +from connect.byte_stream import HTTPCoreResponseAsyncByteStream from connect.code import Code from connect.codec import Codec, CodecNameType -from connect.compression import COMPRESSION_IDENTITY, Compression -from connect.connect import Address, AsyncContentStream, Peer, Spec, StreamingHandlerConn +from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name +from connect.connect import ( + Address, + AsyncContentStream, + Peer, + Spec, + StreamingClientConn, + StreamingHandlerConn, + StreamType, + UnaryClientConn, +) from connect.envelope import EnvelopeReader, EnvelopeWriter from connect.error import ConnectError -from connect.headers import Headers +from connect.headers import Headers, include_request_headers from connect.protocol import ( HEADER_CONTENT_TYPE, + HEADER_USER_AGENT, PROTOCOL_GRPC, + PROTOCOL_GRPC_WEB, Protocol, ProtocolClient, ProtocolClientParams, @@ -28,8 +44,10 @@ negotiate_compression, ) from connect.request import Request +from connect.session import AsyncClientSession from connect.streaming_response import StreamingResponse -from connect.utils import aiterate +from connect.utils import aiterate, map_httpcore_exceptions +from connect.version import __version__ from connect.writer import ServerResponseWriter GRPC_HEADER_COMPRESSION = "Grpc-Encoding" @@ -44,8 +62,12 @@ GRPC_CONTENT_TYPE_PREFIX = GRPC_CONTENT_TYPE_DEFAULT + "+" GRPC_WEB_CONTENT_TYPE_PREFIX = GRPC_WEB_CONTENT_TYPE_DEFAULT + "+" +HEADER_X_USER_AGENT = "X-User-Agent" + GRPC_ALLOWED_METHODS = [HTTPMethod.POST] +DEFAULT_GRPC_USER_AGENT = f"connect-python/{__version__} (Python/{__version__})" + _RE = re.compile(r"^(\d{1,8})([HMSmun])$") _UNIT_TO_SECONDS = { @@ -59,8 +81,8 @@ _MAX_HOURS = (2**63 - 1) // (60 * 60 * 1_000_000_000) -class ProtocolGPRC(Protocol): - """ProtocolGPRC is a protocol implementation for handling gRPC and gRPC-Web requests. +class ProtocolGRPC(Protocol): + """ProtocolGRPC is a protocol implementation for handling gRPC and gRPC-Web requests. Attributes: web (bool): Indicates whether to use gRPC-Web (True) or standard gRPC (False). @@ -115,7 +137,15 @@ def client(self, params: ProtocolClientParams) -> ProtocolClient: ProtocolClient: An instance of GRPCClient. """ - raise NotImplementedError("GRPC client is not implemented yet.") + peer = Peer( + address=Address(host=params.url.host or "", port=params.url.port or 80), + protocol=PROTOCOL_GRPC, + query={}, + ) + if self.web: + peer.protocol = PROTOCOL_GRPC_WEB + + return GRPCClient(params, peer, self.web) class GRPCHandler(ProtocolHandler): @@ -233,7 +263,6 @@ async def conn( spec=self.params.spec, peer=peer, marshaler=GRPCMarshaler( - self.web, codec, response_compression, self.params.compress_min_bytes, @@ -257,11 +286,70 @@ async def conn( return conn +class GRPCClient(ProtocolClient): + params: ProtocolClientParams + _peer: Peer + web: bool + + def __init__(self, params: ProtocolClientParams, peer: Peer, web: bool) -> None: + self.params = params + self._peer = peer + self.web = web + + @property + def peer(self) -> Peer: + return self._peer + + def write_request_headers(self, stream_type: StreamType, headers: Headers) -> None: + if headers.get(HEADER_USER_AGENT, None) is None: + headers[HEADER_USER_AGENT] = DEFAULT_GRPC_USER_AGENT + + if self.web and headers.get(HEADER_X_USER_AGENT, None) is None: + headers[HEADER_X_USER_AGENT] = DEFAULT_GRPC_USER_AGENT + + headers[HEADER_CONTENT_TYPE] = grpc_content_type_from_codec_name(self.web, self.params.codec.name) + + headers["Accept-Encoding"] = COMPRESSION_IDENTITY + if self.params.compression_name and self.params.compression_name != COMPRESSION_IDENTITY: + headers[GRPC_HEADER_COMPRESSION] = self.params.compression_name + + if self.params.compressions: + headers[GRPC_HEADER_ACCEPT_COMPRESSION] = ", ".join(c.name for c in self.params.compressions) + + if not self.web: + headers["Te"] = "trailers" + + def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: + """Return the connection for the client.""" + raise NotImplementedError() + + def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: + """Return the streaming connection for the client.""" + return GRPCClientConn( + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + codec=self.params.codec, + compressions=self.params.compressions, + marshaler=GRPCMarshaler( + codec=self.params.codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + ), + unmarshaler=GRPCUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=headers, + ) + + class GRPCMarshaler(EnvelopeWriter): """GRPCMarshaler is responsible for marshaling messages into the gRPC wire format. Args: - web (bool): Indicates whether to use the gRPC-web protocol. codec (Codec | None): The codec used for encoding/decoding messages. compression (Compression | None): The compression algorithm to use, if any. compress_min_bytes (int): Minimum message size in bytes before compression is applied. @@ -274,11 +362,8 @@ class GRPCMarshaler(EnvelopeWriter): """ - web: bool - def __init__( self, - web: bool, codec: Codec | None, compression: Compression | None, compress_min_bytes: int, @@ -287,7 +372,6 @@ def __init__( """Initialize the protocol with the specified configuration. Args: - web (bool): Indicates whether the protocol is used in a web context. codec (Codec | None): The codec to use for encoding/decoding messages, or None for default. compression (Compression | None): The compression algorithm to use, or None for no compression. compress_min_bytes (int): The minimum number of bytes before compression is applied. @@ -298,7 +382,6 @@ def __init__( """ super().__init__(codec, compression, compress_min_bytes, send_max_bytes) - self.web = web class GRPCUnmarshaler(EnvelopeReader): @@ -349,6 +432,165 @@ async def unmarshal(self, message: Any) -> AsyncIterator[Any]: yield obj +EventHook = Callable[..., Any] + + +class GRPCClientConn(StreamingClientConn): + session: AsyncClientSession + _spec: Spec + _peer: Peer + url: URL + codec: Codec | None + compressions: list[Compression] + marshaler: GRPCMarshaler + unmarshaler: GRPCUnmarshaler + _response_headers: Headers + _response_trailers: Headers + _request_headers: Headers + + def __init__( + self, + session: AsyncClientSession, + spec: Spec, + peer: Peer, + url: URL, + codec: Codec | None, + compressions: list[Compression], + request_headers: Headers, + marshaler: GRPCMarshaler, + unmarshaler: GRPCUnmarshaler, + event_hooks: None | (Mapping[str, list[EventHook]]) = None, + ) -> None: + event_hooks = {} if event_hooks is None else event_hooks + + self.session = session + self._spec = spec + self._peer = peer + self.url = url + self.codec = codec + self.compressions = compressions + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self.response_content = None + self._response_headers = Headers() + self._response_trailers = Headers() + self._request_headers = request_headers + + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def spec(self) -> Spec: + """Return the specification details.""" + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer information.""" + raise NotImplementedError() + + async def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: + """Receives a message and processes it.""" + async for obj in self.unmarshaler.unmarshal(message): + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + + yield obj + + @property + def request_headers(self) -> Headers: + """Return the request headers.""" + return self._request_headers + + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + if abort_event and abort_event.is_set(): + raise ConnectError("request aborted", Code.CANCELED) + + extensions = {} + if timeout: + extensions["timeout"] = {"read": timeout} + self._request_headers[GRPC_HEADER_TIMEOUT] = str(int(timeout * 1000)) + + content_iterator = self.marshaler.marshal(messages) + + request = httpcore.Request( + method=HTTPMethod.POST, + url=httpcore.URL( + scheme=self.url.scheme, + host=self.url.host or "", + port=self.url.port, + target=self.url.raw_path, + ), + headers=list( + include_request_headers( + headers=self._request_headers, url=self.url, content=content_iterator, method=HTTPMethod.POST + ).items() + ), + content=content_iterator, + extensions=extensions, + ) + + for hook in self._event_hooks["request"]: + hook(request) + + with map_httpcore_exceptions(): + if not abort_event: + response = await self.session.pool.handle_async_request(request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task + + raise ConnectError("request aborted", Code.CANCELED) + + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task + + for hook in self._event_hooks["response"]: + hook(response) + + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) + + await self._validate_response(response) + + async def _validate_response(self, response: httpcore.Response) -> None: + response_headers = Headers(response.headers) + + compression = response_headers.get(GRPC_HEADER_COMPRESSION, None) + self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) + self._response_headers.update(response_headers) + + @property + def response_headers(self) -> Headers: + """Return the response headers.""" + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Return response trailers.""" + return self._response_trailers + + def on_request_send(self, fn: EventHook) -> None: + self._event_hooks["request"].append(fn) + + async def aclose(self) -> None: + await self.unmarshaler.aclose() + + class GRPCHandlerConn(StreamingHandlerConn): """GRPCHandlerConn is a handler class for managing gRPC protocol connections within a streaming server context. @@ -634,3 +876,13 @@ def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: trailer[GRPC_HEADER_MESSAGE] = urllib.parse.quote(message) if bin: trailer[GRPC_HEADER_DETAILS] = base64.b64encode(bin).decode().rstrip("=") + + +def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: + if web: + return GRPC_WEB_CONTENT_TYPE_PREFIX + codec_name + + if codec_name == CodecNameType.PROTO: + return GRPC_CONTENT_TYPE_DEFAULT + + return GRPC_CONTENT_TYPE_PREFIX + codec_name From 6ccae5b6c3fd2c181af72324eb35bf63e25da190 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 2 May 2025 06:13:26 +0900 Subject: [PATCH 02/14] protocol_grpc: receive trailer headers --- conformance/client_config.yaml | 2 +- conformance/uv.lock | 16 ++++++---------- pyproject.toml | 5 ++++- src/connect/protocol_grpc.py | 10 ++++++++++ uv.lock | 16 ++++++---------- 5 files changed, 27 insertions(+), 22 deletions(-) diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index 5817efa..facb36e 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -14,7 +14,7 @@ features: # - STREAM_TYPE_UNARY - STREAM_TYPE_CLIENT_STREAM - STREAM_TYPE_SERVER_STREAM - # - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM + - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM # - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM supports_h2c: true diff --git a/conformance/uv.lock b/conformance/uv.lock index a539222..5f64f6d 100644 --- a/conformance/uv.lock +++ b/conformance/uv.lock @@ -108,7 +108,7 @@ requires-dist = [ { name = "anyio", specifier = ">=4.7.0" }, { name = "googleapis-common-protos", specifier = ">=1.70.0" }, { name = "h2", specifier = ">=4.2.0" }, - { name = "httpcore", specifier = ">=1.0.7" }, + { name = "httpcore", git = "https://github.com/tsubakiky/httpcore" }, { name = "protobuf", specifier = ">=5.29.1" }, { name = "pydantic", specifier = ">=2.10.4" }, { name = "starlette", specifier = ">=0.46.0" }, @@ -179,11 +179,11 @@ wheels = [ [[package]] name = "h11" -version = "0.14.0" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418, upload_time = "2022-09-25T15:40:01.519Z" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload_time = "2025-04-24T03:35:25.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload_time = "2022-09-25T15:39:59.68Z" }, + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload_time = "2025-04-24T03:35:24.344Z" }, ] [[package]] @@ -210,16 +210,12 @@ wheels = [ [[package]] name = "httpcore" -version = "1.0.7" -source = { registry = "https://pypi.org/simple" } +version = "1.0.9" +source = { git = "https://github.com/tsubakiky/httpcore#e70c821d72d7b9c5634c781cd454ced911052c29" } dependencies = [ { name = "certifi" }, { name = "h11" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196, upload_time = "2024-11-15T12:30:47.531Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551, upload_time = "2024-11-15T12:30:45.782Z" }, -] [[package]] name = "hypercorn" diff --git a/pyproject.toml b/pyproject.toml index be7b1d1..dcae5ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "anyio>=4.7.0", "googleapis-common-protos>=1.70.0", "h2>=4.2.0", - "httpcore>=1.0.7", + "httpcore", "protobuf>=5.29.1", "pydantic>=2.10.4", "starlette>=0.46.0", @@ -22,6 +22,9 @@ dependencies = [ "yarl>=1.18.3", ] +[tool.uv.sources] +httpcore = { git = "https://github.com/tsubakiky/httpcore" } + [tool.hatch.build.targets.wheel] packages = ["src/connect"] diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 688d4a0..c584c4b 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -3,6 +3,7 @@ import asyncio import base64 import contextlib +import functools import re import urllib.parse from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping @@ -447,6 +448,7 @@ class GRPCClientConn(StreamingClientConn): _response_headers: Headers _response_trailers: Headers _request_headers: Headers + receive_trailers: Callable[[], None] | None def __init__( self, @@ -499,6 +501,13 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None) -> Asyn yield obj + if callable(self.receive_trailers): + self.receive_trailers() + + def _receive_trailers(self, response: httpcore.Response) -> None: + trailers = response.extensions["trailing_headers"] + self._response_trailers.update(Headers(trailers)) + @property def request_headers(self) -> Headers: """Return the request headers.""" @@ -564,6 +573,7 @@ async def send( assert isinstance(response.stream, AsyncIterable) self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) + self.receive_trailers = functools.partial(self._receive_trailers, response) await self._validate_response(response) diff --git a/uv.lock b/uv.lock index b388c30..68c5cd5 100644 --- a/uv.lock +++ b/uv.lock @@ -130,7 +130,7 @@ requires-dist = [ { name = "anyio", specifier = ">=4.7.0" }, { name = "googleapis-common-protos", specifier = ">=1.70.0" }, { name = "h2", specifier = ">=4.2.0" }, - { name = "httpcore", specifier = ">=1.0.7" }, + { name = "httpcore", git = "https://github.com/tsubakiky/httpcore" }, { name = "protobuf", specifier = ">=5.29.1" }, { name = "pydantic", specifier = ">=2.10.4" }, { name = "starlette", specifier = ">=0.46.0" }, @@ -205,11 +205,11 @@ wheels = [ [[package]] name = "h11" -version = "0.14.0" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418, upload_time = "2022-09-25T15:40:01.519Z" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload_time = "2025-04-24T03:35:25.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload_time = "2022-09-25T15:39:59.68Z" }, + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload_time = "2025-04-24T03:35:24.344Z" }, ] [[package]] @@ -236,16 +236,12 @@ wheels = [ [[package]] name = "httpcore" -version = "1.0.7" -source = { registry = "https://pypi.org/simple" } +version = "1.0.9" +source = { git = "https://github.com/tsubakiky/httpcore#e70c821d72d7b9c5634c781cd454ced911052c29" } dependencies = [ { name = "certifi" }, { name = "h11" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196, upload_time = "2024-11-15T12:30:47.531Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551, upload_time = "2024-11-15T12:30:45.782Z" }, -] [[package]] name = "httpx" From 2f47eb212d62d6b0c55e2aa5e7abe0349dfbf106 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Fri, 2 May 2025 19:41:33 +0900 Subject: [PATCH 03/14] protocol: unify the client conn --- conformance/client_config.yaml | 5 +- src/connect/client.py | 7 ++- src/connect/connect.py | 41 +++++++----- src/connect/protocol.py | 8 +-- src/connect/protocol_connect.py | 108 ++++++++++++++++++++------------ src/connect/protocol_grpc.py | 7 +-- 6 files changed, 100 insertions(+), 76 deletions(-) diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index facb36e..22fb4f4 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -3,15 +3,14 @@ features: - HTTP_VERSION_1 - HTTP_VERSION_2 protocols: - # - PROTOCOL_CONNECT - - PROTOCOL_GRPC + - PROTOCOL_CONNECT codecs: - CODEC_PROTO compressions: - COMPRESSION_IDENTITY - COMPRESSION_GZIP stream_types: - # - STREAM_TYPE_UNARY + - STREAM_TYPE_UNARY - STREAM_TYPE_CLIENT_STREAM - STREAM_TYPE_SERVER_STREAM - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM diff --git a/src/connect/client.py b/src/connect/client.py index 451f200..5cdb77a 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -31,6 +31,7 @@ from connect.protocol_connect import ProtocolConnect from connect.protocol_grpc import ProtocolGRPC from connect.session import AsyncClientSession +from connect.utils import aiterate def parse_request_url(raw_url: str) -> URL: @@ -233,9 +234,9 @@ def on_request_send(r: httpcore.Request) -> None: conn.on_request_send(on_request_send) - await conn.send(request.message, request.timeout, abort_event=request.abort_event) + await conn.send(aiterate([request.message]), request.timeout, abort_event=request.abort_event) - response = await recieve_unary_response(conn=conn, t=output) + response = await recieve_unary_response(conn=conn, t=output, abort_event=request.abort_event) return response unary_func = apply_interceptors(_unary_func, options.interceptors) @@ -262,7 +263,7 @@ async def call_unary(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Respon return response async def _stream_func(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: - conn = protocol_client.stream_conn(request.spec, request.headers) + conn = protocol_client.conn(request.spec, request.headers) def on_request_send(r: httpcore.Request) -> None: method = r.method diff --git a/src/connect/connect.py b/src/connect/connect.py index b550b3a..b58ab00 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -705,18 +705,29 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S ) -async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryResponse[T]: - """Receive a unary response from a streaming client connection. +async def recieve_unary_response[T]( + conn: StreamingClientConn, t: type[T], abort_event: asyncio.Event | None +) -> UnaryResponse[T]: + """Receives a unary response message from a streaming client connection. + + This asynchronous function waits for a unary message of the specified type from the given + streaming client connection. It also handles optional abortion via an asyncio event. + The response, along with any headers and trailers from the connection, is wrapped in a + UnaryResponse object and returned. Args: - conn (StreamingClientConn): The streaming client connection. - t (type[T]): The type of the expected response message. + conn (StreamingClientConn): The streaming client connection to receive the message from. + t (type[T]): The expected type of the message to be received. + abort_event (asyncio.Event | None): Optional event to signal abortion of the receive operation. Returns: - UnaryResponse[T]: The response containing the message, response headers, and response trailers. + UnaryResponse[T]: The received message and associated response metadata. + + Raises: + Any exceptions raised by `receive_unary_message` or connection errors. """ - message = await receive_unary_message(conn, t) + message = await receive_unary_message(conn, t, abort_event) return UnaryResponse(message, conn.response_headers, conn.response_trailers) @@ -795,22 +806,20 @@ async def recieve_stream_response[T]( return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) -async def receive_unary_message[T](conn: UnaryClientConn, t: type[T]) -> T: - """Asynchronously receives a single unary message from the given connection. +async def receive_unary_message[T](conn: StreamingClientConn, t: type[T], abort_event: asyncio.Event | None) -> T: + """Receives exactly one unary message of the specified type from a streaming connection. Args: - conn (UnaryClientConn): The unary client connection to receive the message from. - t (type[T]): The expected type of the message to be received. + conn (StreamingClientConn): The streaming client connection to receive the message from. + t (type[T]): The expected type of the message to receive. + abort_event (asyncio.Event | None): An optional event to signal abortion of the receive operation. Returns: - T: The received message of type T. + T: The single message received of type `t`. Raises: - Exception: If receiving the message fails or more than one message is received. - - Note: - This function ensures that exactly one message is received from the connection. + Exception: If zero or more than one message is received, or if the receive operation fails. """ - single_message = await _receive_exactly_one(conn.receive(t), conn.aclose) + single_message = await _receive_exactly_one(conn.receive(t, abort_event), conn.aclose) return single_message diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 94e3e4f..2bce004 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -15,7 +15,6 @@ StreamingClientConn, StreamingHandlerConn, StreamType, - UnaryClientConn, ) from connect.error import ConnectError from connect.headers import Headers @@ -97,15 +96,10 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No raise NotImplementedError() @abc.abstractmethod - def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: + def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: """Return the connection for the client.""" raise NotImplementedError() - @abc.abstractmethod - def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Return the streaming connection for the client.""" - raise NotImplementedError() - class ProtocolHandler(abc.ABC): """Abstract base class for handling different protocols.""" diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 8ae087c..8334550 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -32,7 +32,6 @@ StreamingClientConn, StreamingHandlerConn, StreamType, - UnaryClientConn, ensure_single, ) from connect.envelope import EnvelopeFlags, EnvelopeReader, EnvelopeWriter @@ -865,7 +864,7 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No if self.params.compressions: headers[accept_compression_header] = ", ".join(c.name for c in self.params.compressions) - def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: + def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: """Establish a unary client connection with the given specifications and headers. Args: @@ -876,32 +875,54 @@ def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: UnaryClientConn: The established unary client connection. """ - conn = ConnectUnaryClientConn( - session=self.params.session, - spec=spec, - peer=self.peer, - url=self.params.url, - compressions=self.params.compressions, - request_headers=headers, - marshaler=ConnectUnaryRequestMarshaler( - connect_marshaler=ConnectUnaryMarshaler( + conn: StreamingClientConn + if spec.stream_type == StreamType.Unary: + conn = ConnectUnaryClientConn( + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + compressions=self.params.compressions, + request_headers=headers, + marshaler=ConnectUnaryRequestMarshaler( + connect_marshaler=ConnectUnaryMarshaler( + codec=self.params.codec, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + headers=headers, + ) + ), + unmarshaler=ConnectUnaryUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + ) + if spec.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: + conn.marshaler.enable_get = self.params.enable_get + conn.marshaler.url = self.params.url + if isinstance(self.params.codec, StableCodec): + conn.marshaler.stable_codec = self.params.codec + else: + conn = ConnectStreamingClientConn( + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + codec=self.params.codec, + compressions=self.params.compressions, + request_headers=headers, + marshaler=ConnectStreamingMarshaler( codec=self.params.codec, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), compress_min_bytes=self.params.compress_min_bytes, send_max_bytes=self.params.send_max_bytes, - headers=headers, - ) - ), - unmarshaler=ConnectUnaryUnmarshaler( - codec=self.params.codec, - read_max_bytes=self.params.read_max_bytes, - ), - ) - if spec.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: - conn.marshaler.enable_get = self.params.enable_get - conn.marshaler.url = self.params.url - if isinstance(self.params.codec, StableCodec): - conn.marshaler.stable_codec = self.params.codec + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + ), + unmarshaler=ConnectStreamingUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + ) return conn @@ -1774,7 +1795,7 @@ async def aclose(self) -> None: await self.unmarshaler.aclose() -class ConnectUnaryClientConn(UnaryClientConn): +class ConnectUnaryClientConn(StreamingClientConn): """A client connection for unary RPCs using the Connect protocol. Attributes: @@ -1883,7 +1904,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: obj = await self.unmarshaler.unmarshal(message) yield obj - def receive(self, message: Any) -> AsyncIterator[Any]: + def receive(self, message: Any, _abort_event: asyncio.Event | None) -> AsyncIterator[Any]: """Receives a message and returns an asynchronous iterator over the processed message. Args: @@ -1916,22 +1937,28 @@ def on_request_send(self, fn: EventHook) -> None: """ self._event_hooks["request"].append(fn) - async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: - """Send a message asynchronously using the specified HTTP method and handles the response. + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + """Send a single message asynchronously using either HTTP GET or POST, with support for timeouts and request abortion. Args: - message (Any): The message to be sent, which will be marshaled before sending. - timeout (float | None): The timeout for the request in seconds. If provided, it will be - included in the request headers and extensions. - abort_event (asyncio.Event | None): An optional asyncio event that can be used to abort - the request. If the event is set, the request will be canceled. - - Returns: - bytes: The marshaled data of the message that was sent. + messages (AsyncIterable[Any]): An asynchronous iterable yielding the message(s) to send. Only a single message is allowed. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets a read timeout for the request. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, aborts the request. Raises: - ConnectError: If the request is aborted or if there are issues during the request/response - lifecycle. + ConnectError: If the request is aborted before or during execution, or if other connection errors occur. + + Side Effects: + - Modifies request headers for timeout and content length as needed. + - Invokes registered request and response event hooks. + - Sets the unmarshaler's stream to the response stream for further processing. + - Validates the response after receiving it. + + Notes: + - If `marshaler.enable_get` is True, sends the request as HTTP GET; otherwise, uses HTTP POST. + - Handles cancellation and cleanup if the abort event is triggered during the request. """ if abort_event and abort_event.is_set(): @@ -1942,6 +1969,7 @@ async def send(self, message: Any, timeout: float | None, abort_event: asyncio.E extensions["timeout"] = {"read": timeout} self._request_headers[CONNECT_HEADER_TIMEOUT] = str(int(timeout * 1000)) + message = await ensure_single(messages) data = self.marshaler.marshal(message) if self.marshaler.enable_get: @@ -2015,8 +2043,6 @@ async def send(self, message: Any, timeout: float | None, abort_event: asyncio.E await self._validate_response(response) - return data - @property def response_headers(self) -> Headers: """Return the response headers. diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index c584c4b..c67efcc 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -26,7 +26,6 @@ StreamingClientConn, StreamingHandlerConn, StreamType, - UnaryClientConn, ) from connect.envelope import EnvelopeReader, EnvelopeWriter from connect.error import ConnectError @@ -320,11 +319,7 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No if not self.web: headers["Te"] = "trailers" - def conn(self, spec: Spec, headers: Headers) -> UnaryClientConn: - """Return the connection for the client.""" - raise NotImplementedError() - - def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: + def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: """Return the streaming connection for the client.""" return GRPCClientConn( session=self.params.session, From 16a3c716550ad789a29d984be5efb624084d1d70 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 3 May 2025 09:05:14 +0900 Subject: [PATCH 04/14] protocol_grpc: add trailer validation --- conformance/client_config.yaml | 11 +-- conformance/run-testcase.txt | 2 +- src/connect/envelope.py | 5 ++ src/connect/protocol_connect.py | 33 -------- src/connect/protocol_grpc.py | 131 +++++++++++++++++++++++++++++++- 5 files changed, 141 insertions(+), 41 deletions(-) diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index 22fb4f4..44a7a14 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -3,7 +3,8 @@ features: - HTTP_VERSION_1 - HTTP_VERSION_2 protocols: - - PROTOCOL_CONNECT + # - PROTOCOL_CONNECT + - PROTOCOL_GRPC codecs: - CODEC_PROTO compressions: @@ -11,9 +12,9 @@ features: - COMPRESSION_GZIP stream_types: - STREAM_TYPE_UNARY - - STREAM_TYPE_CLIENT_STREAM - - STREAM_TYPE_SERVER_STREAM - - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM + # - STREAM_TYPE_CLIENT_STREAM + # - STREAM_TYPE_SERVER_STREAM + # - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM # - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM supports_h2c: true @@ -22,4 +23,4 @@ features: supports_trailers: true supports_half_duplex_bidi_over_http1: true supports_connect_get: true - supports_message_receive_limit: true + supports_message_receive_limit: false diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index 92362ec..b762a2f 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -Basic/HTTPVersion:2/Protocol:PROTOCOL_GRPC/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:false/client-stream/success +gRPC Trailers/Compression:COMPRESSION_GZIP/TLS:false/trailers-only/duplicate-metadata diff --git a/src/connect/envelope.py b/src/connect/envelope.py index c29b490..01bf131 100644 --- a/src/connect/envelope.py +++ b/src/connect/envelope.py @@ -252,6 +252,7 @@ class EnvelopeReader: compression: Compression | None stream: AsyncIterable[bytes] | None buffer: bytes + bytes_read: int last_data: bytes | None def __init__( @@ -275,6 +276,7 @@ def __init__( self.compression = compression self.stream = stream self.buffer = b"" + self.bytes_read = 0 self.last_data = None async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: @@ -299,6 +301,7 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: async for chunk in self.stream: self.buffer += chunk + self.bytes_read += len(chunk) while True: env, data_len = Envelope.decode(self.buffer) @@ -352,6 +355,8 @@ async def aclose(self) -> None: This method checks if the `self.stream` object has an asynchronous `aclose` method. If the method exists, it is invoked to close the stream. + The bytes_read counter is not reset when closing the stream. + Returns: None diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 8334550..9c342d5 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -926,39 +926,6 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: return conn - def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Establish a streaming connection using the provided specification and headers. - - Args: - spec (Spec): The specification for the streaming connection. - headers (Headers): The headers to be included in the connection request. - - Returns: - StreamingClientConn: An instance of the streaming client connection. - - """ - conn = ConnectStreamingClientConn( - session=self.params.session, - spec=spec, - peer=self.peer, - url=self.params.url, - codec=self.params.codec, - compressions=self.params.compressions, - request_headers=headers, - marshaler=ConnectStreamingMarshaler( - codec=self.params.codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), - ), - unmarshaler=ConnectStreamingUnmarshaler( - codec=self.params.codec, - read_max_bytes=self.params.read_max_bytes, - ), - ) - - return conn - class ConnectUnaryRequestMarshaler: """A class responsible for marshaling unary requests using a provided ConnectUnaryMarshaler. diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index c67efcc..32732fd 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -11,6 +11,7 @@ from typing import Any import httpcore +from google.protobuf.message import DecodeError from google.rpc import status_pb2 from yarl import URL @@ -28,7 +29,7 @@ StreamType, ) from connect.envelope import EnvelopeReader, EnvelopeWriter -from connect.error import ConnectError +from connect.error import ConnectError, ErrorDetail from connect.headers import Headers, include_request_headers from connect.protocol import ( HEADER_CONTENT_TYPE, @@ -40,6 +41,7 @@ ProtocolClientParams, ProtocolHandler, ProtocolHandlerParams, + code_from_http_status, exclude_protocol_headers, negotiate_compression, ) @@ -322,6 +324,7 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: """Return the streaming connection for the client.""" return GRPCClientConn( + web=self.web, session=self.params.session, spec=spec, peer=self.peer, @@ -432,6 +435,7 @@ async def unmarshal(self, message: Any) -> AsyncIterator[Any]: class GRPCClientConn(StreamingClientConn): + web: bool session: AsyncClientSession _spec: Spec _peer: Peer @@ -447,6 +451,7 @@ class GRPCClientConn(StreamingClientConn): def __init__( self, + web: bool, session: AsyncClientSession, spec: Spec, peer: Peer, @@ -460,6 +465,7 @@ def __init__( ) -> None: event_hooks = {} if event_hooks is None else event_hooks + self.web = web self.session = session self._spec = spec self._peer = peer @@ -499,7 +505,24 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None) -> Asyn if callable(self.receive_trailers): self.receive_trailers() + if self.unmarshaler.bytes_read == 0 and len(self.response_trailers) == 0: + self.response_trailers.update(self._response_headers) + del self._response_headers[HEADER_CONTENT_TYPE] + server_error = grpc_error_from_trailer(self.response_trailers) + if server_error: + server_error.metadata = self.response_headers.copy() + raise server_error + + server_error = grpc_error_from_trailer(self.response_trailers) + if server_error: + server_error.metadata = self.response_headers.copy() + server_error.metadata.update(self.response_trailers) + raise server_error + def _receive_trailers(self, response: httpcore.Response) -> None: + if "trailing_headers" not in response.extensions: + return + trailers = response.extensions["trailing_headers"] self._response_trailers.update(Headers(trailers)) @@ -574,9 +597,22 @@ async def send( async def _validate_response(self, response: httpcore.Response) -> None: response_headers = Headers(response.headers) + if response.status != 200: + raise ConnectError( + f"HTTP {response.status}", + code_from_http_status(response.status), + ) + + grpc_validate_response_content_type( + self.web, + self.marshaler.codec.name if self.marshaler.codec else "", + response_headers.get(HEADER_CONTENT_TYPE, ""), + ) compression = response_headers.get(GRPC_HEADER_COMPRESSION, None) - self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) + if compression and compression != COMPRESSION_IDENTITY: + self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) + self._response_headers.update(response_headers) @property @@ -891,3 +927,94 @@ def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: return GRPC_CONTENT_TYPE_DEFAULT return GRPC_CONTENT_TYPE_PREFIX + codec_name + + +def grpc_validate_response_content_type(web: bool, request_codec_name: str, response_content_type: str) -> None: + bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX + if web: + bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX + + if response_content_type == prefix + request_codec_name or ( + request_codec_name == CodecNameType.PROTO and response_content_type == bare + ): + return + + expected_content_type = bare + if request_codec_name != CodecNameType.PROTO: + expected_content_type = prefix + request_codec_name + + code = Code.INTERNAL + if response_content_type != bare and not response_content_type.startswith(prefix): + code = Code.UNKNOWN + + raise ConnectError(f"invalid content-type {response_content_type}, expected {expected_content_type}", code) + + +def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: + code_header = trailers.get(GRPC_HEADER_STATUS) + if code_header is None: + code = Code.UNKNOWN + if len(trailers) == 0: + code = Code.INTERNAL + + return ConnectError( + f"protocol error: no {GRPC_HEADER_STATUS} header in trailers", + code, + ) + + if code_header == "0": + return None + + try: + code = Code(int(code_header)) + except ValueError: + return ConnectError( + f"protocol error: invalid error code {code_header} in trailers", + ) + + message = trailers.get(GRPC_HEADER_MESSAGE, None) + if message is None: + return ConnectError( + f"protocol error: invalid error message {code_header} in trailers", + code=Code.UNKNOWN, + ) + + ret_error = ConnectError( + message, + code, + wire_error=True, + ) + + details_binary_encoded = trailers.get(GRPC_HEADER_DETAILS, None) + if details_binary_encoded and len(details_binary_encoded) > 0: + try: + details_binary = decode_binary_header(details_binary_encoded) + except Exception as e: + raise ConnectError( + f"server returned invalid grpc-status-details-bin trailer: {e}", + code=Code.INTERNAL, + ) from e + + status = status_pb2.Status() + try: + status.ParseFromString(details_binary) + except DecodeError as e: + raise ConnectError( + f"server returned invalid protobuf for error details: {e}", + code=Code.INTERNAL, + ) from e + + for detail in status.details: + ret_error.details.append(ErrorDetail(pb_any=detail)) + + ret_error.code = Code(status.code) + ret_error.raw_message = status.message + + return ret_error + + +def decode_binary_header(data: str) -> bytes: + if len(data) % 4: + data += "=" * (-len(data) % 4) + + return base64.b64decode(data, validate=True) From d2a6182e058fd11ebc48d277c6deb3c4250263d2 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 3 May 2025 09:19:10 +0900 Subject: [PATCH 05/14] protocol_grpc: fix grpc message parse --- conformance/run-testcase.txt | 2 +- src/connect/protocol_grpc.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index b762a2f..54c6ca9 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -gRPC Trailers/Compression:COMPRESSION_GZIP/TLS:false/trailers-only/duplicate-metadata +gRPC Unexpected Responses/HTTPVersion:2/TLS:true/trailers-only/ignore-header-if-trailer-present diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 32732fd..63ed758 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -9,6 +9,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping from http import HTTPMethod from typing import Any +from urllib.parse import unquote import httpcore from google.protobuf.message import DecodeError @@ -972,8 +973,9 @@ def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: f"protocol error: invalid error code {code_header} in trailers", ) - message = trailers.get(GRPC_HEADER_MESSAGE, None) - if message is None: + try: + message = unquote(trailers.get(GRPC_HEADER_MESSAGE, "")) + except Exception: return ConnectError( f"protocol error: invalid error message {code_header} in trailers", code=Code.UNKNOWN, From 46888546a1054beadfd3b69fd330015bd8ff1f96 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 3 May 2025 09:34:23 +0900 Subject: [PATCH 06/14] protocol_grpc: encode grpc-timeout --- conformance/run-testcase.txt | 2 +- src/connect/protocol_grpc.py | 26 +++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index 54c6ca9..a71ce69 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -gRPC Unexpected Responses/HTTPVersion:2/TLS:true/trailers-only/ignore-header-if-trailer-present +Timeouts/HTTPVersion:2/Protocol:PROTOCOL_GRPC/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:true/unary diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 63ed758..37637f6 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -541,7 +541,7 @@ async def send( extensions = {} if timeout: extensions["timeout"] = {"read": timeout} - self._request_headers[GRPC_HEADER_TIMEOUT] = str(int(timeout * 1000)) + self._request_headers[GRPC_HEADER_TIMEOUT] = grpc_encode_timeout(timeout) content_iterator = self.marshaler.marshal(messages) @@ -1020,3 +1020,27 @@ def decode_binary_header(data: str) -> bytes: data += "=" * (-len(data) % 4) return base64.b64decode(data, validate=True) + + +def grpc_encode_timeout(timeout: float) -> str: + if timeout <= 0: + return "0n" + + grpc_timeout_max_value = 10**8 + + _units = ( + (1e-9, "n"), + (1e-6, "u"), + (1e-3, "m"), + (1.0, "S"), + (60.0, "M"), + (3600.0, "H"), + ) + + for size, unit in _units: + if timeout < size * grpc_timeout_max_value: + value = int(timeout / size) + return f"{value}{unit}" + + value = int(timeout / 3600.0) + return f"{value}H" From fcd69ebe6a0ceb23fed2dcddd411766ffc9b75bf Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 3 May 2025 09:38:31 +0900 Subject: [PATCH 07/14] conformance: support streaming --- conformance/client_config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index 44a7a14..0ad7321 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -3,7 +3,7 @@ features: - HTTP_VERSION_1 - HTTP_VERSION_2 protocols: - # - PROTOCOL_CONNECT + - PROTOCOL_CONNECT - PROTOCOL_GRPC codecs: - CODEC_PROTO @@ -12,9 +12,9 @@ features: - COMPRESSION_GZIP stream_types: - STREAM_TYPE_UNARY - # - STREAM_TYPE_CLIENT_STREAM - # - STREAM_TYPE_SERVER_STREAM - # - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM + - STREAM_TYPE_CLIENT_STREAM + - STREAM_TYPE_SERVER_STREAM + - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM # - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM supports_h2c: true From 4ce4b114b8b18a44b2f2086451340b36b22b94bf Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sun, 4 May 2025 17:38:45 +0900 Subject: [PATCH 08/14] protocol_grpc: fix grpc timeout dict --- src/connect/protocol_grpc.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 37637f6..a8565c2 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -5,6 +5,7 @@ import contextlib import functools import re +import sys import urllib.parse from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping from http import HTTPMethod @@ -74,14 +75,14 @@ _RE = re.compile(r"^(\d{1,8})([HMSmun])$") _UNIT_TO_SECONDS = { - "H": 60 * 60, - "M": 60, - "S": 1, - "m": 1e-3, # millisecond - "u": 1e-6, # microsecond "n": 1e-9, # nanosecond + "u": 1e-6, # microsecond + "m": 1e-3, # millisecond + "S": 1.0, + "M": 60.0, + "H": 3600.0, } -_MAX_HOURS = (2**63 - 1) // (60 * 60 * 1_000_000_000) +_MAX_HOURS = sys.maxsize // (60 * 60 * 1_000_000_000) class ProtocolGRPC(Protocol): @@ -716,7 +717,6 @@ def parse_timeout(self) -> float | None: num_str, unit = m.groups() num = int(num_str) - if num > 99_999_999: raise ConnectError(f"protocol error: timeout {timeout!r} is too long") @@ -1028,16 +1028,9 @@ def grpc_encode_timeout(timeout: float) -> str: grpc_timeout_max_value = 10**8 - _units = ( - (1e-9, "n"), - (1e-6, "u"), - (1e-3, "m"), - (1.0, "S"), - (60.0, "M"), - (3600.0, "H"), - ) + _units = dict(sorted(_UNIT_TO_SECONDS.items(), key=lambda item: item[1])) - for size, unit in _units: + for unit, size in _units.items(): if timeout < size * grpc_timeout_max_value: value = int(timeout / size) return f"{value}{unit}" From c14fb7ffa21b62daee2473fb5352502c30fb827a Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sun, 4 May 2025 17:58:22 +0900 Subject: [PATCH 09/14] protocol_grpc: add doc --- src/connect/byte_stream.py | 2 + src/connect/protocol_grpc.py | 207 ++++++++++++++++++++++++++++++++++- 2 files changed, 203 insertions(+), 6 deletions(-) diff --git a/src/connect/byte_stream.py b/src/connect/byte_stream.py index 45bd122..5033d9d 100644 --- a/src/connect/byte_stream.py +++ b/src/connect/byte_stream.py @@ -1,3 +1,5 @@ +"""Asynchronous byte stream utilities for HTTP core response handling.""" + from collections.abc import ( AsyncIterable, AsyncIterator, diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index a8565c2..79ff1e2 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -291,20 +291,62 @@ async def conn( class GRPCClient(ProtocolClient): + """GRPCClient is a protocol client implementation for gRPC communication, supporting both standard and web environments. + + Attributes: + params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, session, and URL. + _peer (Peer): The peer instance associated with this client, representing the remote endpoint. + web (bool): Indicates whether the client is running in a web environment, affecting header and content-type handling. + + """ + params: ProtocolClientParams _peer: Peer web: bool def __init__(self, params: ProtocolClientParams, peer: Peer, web: bool) -> None: + """Initialize the ProtocolClient with the given parameters. + + Args: + params (ProtocolClientParams): The parameters for the protocol client. + peer (Peer): The peer instance to be used. + web (bool): Indicates whether the client is running in a web environment. + + """ self.params = params self._peer = peer self.web = web @property def peer(self) -> Peer: + """Returns the associated Peer object. + + Returns: + Peer: The peer instance associated with this object. + + """ return self._peer - def write_request_headers(self, stream_type: StreamType, headers: Headers) -> None: + def write_request_headers(self, _: StreamType, headers: Headers) -> None: + """Set and modifies HTTP/2 or gRPC request headers based on the stream type, connection parameters, and environment. + + Args: + stream_type (StreamType): The type of stream for which headers are being written. + headers (Headers): The dictionary of headers to be modified or populated. + + Behavior: + - Ensures the 'User-Agent' header is set to the default gRPC user agent if not already present. + - If running in a web environment, also sets the 'X-User-Agent' header. + - Sets the 'Content-Type' header according to the codec name and environment. + - Sets the 'Accept-Encoding' header to indicate supported compression. + - If a specific compression is configured and is not the identity, sets the gRPC compression header. + - If multiple compressions are supported, sets the gRPC accept compression header with the supported values. + - For non-web environments, adds the 'Te: trailers' header required for gRPC. + + Note: + This method mutates the provided headers dictionary in place. + + """ if headers.get(HEADER_USER_AGENT, None) is None: headers[HEADER_USER_AGENT] = DEFAULT_GRPC_USER_AGENT @@ -324,7 +366,21 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No headers["Te"] = "trailers" def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Return the streaming connection for the client.""" + """Create and returns a GRPCClientConn instance configured with the provided specification and headers. + + Args: + spec (Spec): The specification object defining the protocol or service interface. + headers (Headers): The request headers to include in the connection. + + Returns: + StreamingClientConn: An initialized gRPC streaming client connection. + + Details: + - Configures the connection with parameters such as session, peer, URL, codec, and compression settings. + - Initializes GRPCMarshaler and GRPCUnmarshaler with appropriate codecs and limits. + - Compression is determined using the provided compression name and available compressions. + + """ return GRPCClientConn( web=self.web, session=self.params.session, @@ -437,6 +493,26 @@ async def unmarshal(self, message: Any) -> AsyncIterator[Any]: class GRPCClientConn(StreamingClientConn): + """GRPCClientConn is a gRPC client connection implementation supporting asynchronous streaming requests and responses over HTTP/2. + + This class manages the lifecycle of a gRPC client connection, including marshaling and unmarshaling messages, handling request and response headers/trailers, managing compression, and supporting event hooks for request/response events. It integrates with an asynchronous HTTP client session and supports cancellation via asyncio events. + + Attributes: + session (AsyncClientSession): The asynchronous client session used for HTTP requests. + _spec (Spec): The protocol or API specification. + _peer (Peer): Information about the remote peer. + url (URL): The endpoint URL for the connection. + codec (Codec | None): Codec for encoding/decoding messages. + compressions (list[Compression]): Supported compression algorithms. + marshaler (GRPCMarshaler): Marshaler for serializing messages. + unmarshaler (GRPCUnmarshaler): Unmarshaler for deserializing messages. + _response_headers (Headers): HTTP response headers. + _response_trailers (Headers): HTTP response trailers. + _request_headers (Headers): HTTP request headers. + receive_trailers (Callable[[], None] | None): Callback to receive trailers after response. + + """ + web: bool session: AsyncClientSession _spec: Spec @@ -465,6 +541,22 @@ def __init__( unmarshaler: GRPCUnmarshaler, event_hooks: None | (Mapping[str, list[EventHook]]) = None, ) -> None: + """Initialize a new instance of the class. + + Args: + web (bool): Indicates if the connection is for a web environment. + session (AsyncClientSession): The asynchronous client session to use for requests. + spec (Spec): The specification object describing the protocol or API. + peer (Peer): The peer information for the connection. + url (URL): The URL endpoint for the connection. + codec (Codec | None): The codec to use for encoding/decoding messages, or None. + compressions (list[Compression]): List of supported compression algorithms. + request_headers (Headers): Headers to include in outgoing requests. + marshaler (GRPCMarshaler): The marshaler for serializing messages. + unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing messages. + event_hooks (None | Mapping[str, list[EventHook]], optional): Optional mapping of event hooks for "request" and "response" events. Defaults to None. + + """ event_hooks = {} if event_hooks is None else event_hooks self.web = web @@ -536,6 +628,22 @@ def request_headers(self) -> Headers: async def send( self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: + """Send a gRPC request asynchronously using HTTP/2 via httpcore, handling streaming messages, timeouts, and abort events. + + Args: + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent as the request body. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets the gRPC timeout header. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request and raise a cancellation error. + + Raises: + ConnectError: If the request is aborted before or during execution, or if an error occurs during the HTTP request. + + Side Effects: + - Invokes registered request and response event hooks. + - Sets up the response stream and trailers for further processing. + - Validates the HTTP response. + + """ if abort_event and abort_event.is_set(): raise ConnectError("request aborted", Code.CANCELED) @@ -628,9 +736,23 @@ def response_trailers(self) -> Headers: return self._response_trailers def on_request_send(self, fn: EventHook) -> None: + """Register a callback function to be invoked when a request is sent. + + Args: + fn (EventHook): The callback function to be added to the "request" event hook. + + Returns: + None + + """ self._event_hooks["request"].append(fn) async def aclose(self) -> None: + """Asynchronously closes the underlying unmarshaler resource. + + This method should be called to properly release any resources held by the unmarshaler, + such as open network connections or file handles, when they are no longer needed. + """ await self.unmarshaler.aclose() @@ -909,18 +1031,28 @@ def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: ) code = status.code message = status.message - bin = None + details_binary = None if len(status.details) > 0: - bin = status.SerializeToString() + details_binary = status.SerializeToString() trailer[GRPC_HEADER_STATUS] = str(code) trailer[GRPC_HEADER_MESSAGE] = urllib.parse.quote(message) - if bin: - trailer[GRPC_HEADER_DETAILS] = base64.b64encode(bin).decode().rstrip("=") + if details_binary: + trailer[GRPC_HEADER_DETAILS] = base64.b64encode(details_binary).decode().rstrip("=") def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: + """Return the appropriate gRPC content type string based on the given codec name and whether the request is for gRPC-Web. + + Args: + web (bool): Indicates if the content type is for gRPC-Web (True) or standard gRPC (False). + codec_name (str): The name of the codec (e.g., "proto", "json"). + + Returns: + str: The corresponding gRPC content type string. + + """ if web: return GRPC_WEB_CONTENT_TYPE_PREFIX + codec_name @@ -931,6 +1063,17 @@ def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: def grpc_validate_response_content_type(web: bool, request_codec_name: str, response_content_type: str) -> None: + """Validate that the gRPC response content type matches the expected value based on the request codec and whether gRPC-Web is used. + + Args: + web (bool): Indicates if gRPC-Web is being used. + request_codec_name (str): The name of the codec used in the request (e.g., "proto", "json"). + response_content_type (str): The content type returned in the response. + + Raises: + ConnectError: If the response content type does not match the expected value, with an appropriate error code. + + """ bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX if web: bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX @@ -952,6 +1095,24 @@ def grpc_validate_response_content_type(web: bool, request_codec_name: str, resp def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: + """Parse gRPC error information from response trailers and constructs a ConnectError if present. + + Args: + trailers (Headers): The gRPC response trailers containing error information. + + Returns: + ConnectError | None: Returns a ConnectError instance if an error is found in the trailers, + or None if the status code indicates success. + + Raises: + ConnectError: If the grpc-status-details-bin trailer or protobuf error details are invalid. + + The function extracts the gRPC status code, error message, and optional error details from the trailers. + If the status code is missing or invalid, it returns a ConnectError with an appropriate message. + If the status code indicates success ("0"), it returns None. + If error details are present and valid, they are attached to the ConnectError. + + """ code_header = trailers.get(GRPC_HEADER_STATUS) if code_header is None: code = Code.UNKNOWN @@ -1016,6 +1177,21 @@ def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: def decode_binary_header(data: str) -> bytes: + """Decode a base64-encoded string representing a binary header. + + If the input string's length is not a multiple of 4, it pads the string with '=' characters + to make it valid base64 before decoding. + + Args: + data (str): The base64-encoded string to decode. + + Returns: + bytes: The decoded binary data. + + Raises: + binascii.Error: If the input is not correctly base64-encoded. + + """ if len(data) % 4: data += "=" * (-len(data) % 4) @@ -1023,6 +1199,25 @@ def decode_binary_header(data: str) -> bytes: def grpc_encode_timeout(timeout: float) -> str: + """Encode a timeout value (in seconds) into the gRPC timeout format string. + + The gRPC timeout format is a decimal number with a time unit suffix, where the unit can be: + - 'H' for hours + - 'M' for minutes + - 'S' for seconds + - 'm' for milliseconds + - 'u' for microseconds + - 'n' for nanoseconds + + If the timeout is less than or equal to zero, returns "0n". + + Args: + timeout (float): The timeout value in seconds. + + Returns: + str: The timeout encoded as a gRPC timeout string. + + """ if timeout <= 0: return "0n" From 33af74cb08adaf1dc5827cdcb3e233452d28e9a7 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sun, 4 May 2025 18:28:41 +0900 Subject: [PATCH 10/14] connect: refactor ensure_single func --- src/connect/connect.py | 107 +++++++++-------------------------------- 1 file changed, 24 insertions(+), 83 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index b58ab00..1a89692 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -361,51 +361,25 @@ async def __aiter__(self) -> AsyncIterator[T]: """ if self.stream_type == StreamType.Unary or self.stream_type == StreamType.ServerStream: - async for item in validate_single_content_stream(self._iterable): - yield item + item = await ensure_single(self._iterable) + yield item else: async for item in self._iterable: yield item -async def validate_single_content_stream[T](iterable: AsyncIterable[T]) -> AsyncIterator[T]: - """Validate that an asynchronous iterable yields exactly one item. - - This async generator iterates over the provided `iterable` and ensures that it produces exactly one item. - If more than one item is yielded, a `ConnectError` is raised indicating a protocol error. - If no items are yielded, a `ConnectError` is also raised. - - Args: - iterable (AsyncIterable[T]): The asynchronous iterable to validate. - - Yields: - T: The single item from the iterable. - - Raises: - ConnectError: If the iterable yields zero or more than one item. - - """ - count = 0 - async for item in iterable: - if count > 0: - raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) - - yield item - count += 1 - - if count == 0: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) - - -async def ensure_single[T](iterable: AsyncIterable[T]) -> T: +async def ensure_single[T](iterable: AsyncIterable[T], aclose: Callable[[], Awaitable[None]] | None = None) -> T: """Asynchronously ensures that the given async iterable yields exactly one item. Iterates over the provided async iterable (after validating its content stream) and returns the single item if present. Raises a ConnectError if the iterable - is empty or contains more than one item. + is empty or contains more than one item. Optionally closes the iterable by calling + the provided aclose function after processing. Args: iterable (AsyncIterable[T]): An asynchronous iterable expected to yield exactly one item. + aclose (Callable[[], Awaitable[None]] | None, optional): A callable that asynchronously + closes the stream when invoked. If provided, will be called in a finally block. Returns: T: The single item yielded by the iterable. @@ -414,14 +388,20 @@ async def ensure_single[T](iterable: AsyncIterable[T]) -> T: ConnectError: If the iterable yields no items or more than one item. """ - message = None - async for item in validate_single_content_stream(iterable): - message = item - - if message is None: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) - - return message + try: + iterator = iterable.__aiter__() + try: + first = await iterator.__anext__() + try: + await iterator.__anext__() + raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) + except StopAsyncIteration: + return first + except StopAsyncIteration: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) from None + finally: + if aclose: + await aclose() class StreamingHandlerConn(abc.ABC): @@ -732,44 +712,6 @@ async def recieve_unary_response[T]( return UnaryResponse(message, conn.response_headers, conn.response_trailers) -async def _receive_exactly_one[T](stream: AsyncIterator[T], aclose: Callable[[], Awaitable[None]]) -> T: - """Asynchronously receives exactly one item from an asynchronous iterator. - - This function ensures that the provided asynchronous iterator (`stream`) yields - exactly one item. If the iterator yields no items or more than one item, a - `ConnectError` is raised. The provided `aclose` callable is always invoked to - close the stream, regardless of success or failure. - - Type Parameters: - T: The type of the items in the asynchronous iterator. - - Args: - stream (AsyncIterator[T]): The asynchronous iterator to consume. - aclose (Callable[[], Awaitable[None]]): A callable that closes the stream - when invoked. - - Returns: - T: The single item yielded by the asynchronous iterator. - - Raises: - ConnectError: If the iterator yields no items or more than one item. - - """ - try: - first = await stream.__anext__() - try: - await stream.__anext__() - raise ConnectError( - "ClientStream should only receive one message, but received multiple.", Code.UNIMPLEMENTED - ) - except StopAsyncIteration: - return first - except StopAsyncIteration: - raise ConnectError("ClientStream should receive one message, but received none.", Code.UNIMPLEMENTED) from None - finally: - await aclose() - - async def recieve_stream_response[T]( conn: StreamingClientConn, t: type[T], spec: Spec, abort_event: asyncio.Event | None ) -> StreamResponse[T]: @@ -797,7 +739,7 @@ async def recieve_stream_response[T]( receive_stream = AsyncDataStream[T](conn.receive(t, abort_event), conn.aclose) if spec.stream_type == StreamType.ClientStream: - single_message = await _receive_exactly_one(receive_stream.__aiter__(), receive_stream.aclose) + single_message = await ensure_single(receive_stream, receive_stream.aclose) return StreamResponse( AsyncDataStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers @@ -818,8 +760,7 @@ async def receive_unary_message[T](conn: StreamingClientConn, t: type[T], abort_ T: The single message received of type `t`. Raises: - Exception: If zero or more than one message is received, or if the receive operation fails. + ConnectError: If zero or more than one message is received, or if the receive operation fails. """ - single_message = await _receive_exactly_one(conn.receive(t, abort_event), conn.aclose) - return single_message + return await ensure_single(conn.receive(t, abort_event), conn.aclose) From 2962fd77242d6cb665515f078d43a40a1b8bd5ba Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sun, 4 May 2025 18:54:21 +0900 Subject: [PATCH 11/14] connect: refacter recv stream --- src/connect/connect.py | 107 +++++++------------------------- src/connect/handler.py | 8 +++ src/connect/protocol_connect.py | 15 ++--- src/connect/protocol_grpc.py | 5 +- 4 files changed, 38 insertions(+), 97 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index 1a89692..090e287 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -318,56 +318,6 @@ async def aclose(self) -> None: await aclose() -class AsyncContentStream[T](AsyncIterable[T]): - """AsyncContentStream is a generic asynchronous stream wrapper for async iterables, providing validation and iteration utilities based on stream type. - - Type Parameters: - T: The type of elements yielded by the asynchronous iterable. - - iterable (AsyncIterable[T]): The asynchronous iterable to wrap. - stream_type (StreamType): The type of stream (e.g., Unary, ServerStream) that determines validation behavior. - - Attributes: - _iterable (AsyncIterable[T]): The underlying asynchronous iterable. - stream_type (StreamType): The type of stream this instance represents. - - """ - - def __init__(self, iterable: AsyncIterable[T], stream_type: StreamType) -> None: - """Initialize a stream wrapper for an async iterable. - - This constructor stores the provided async iterable and its corresponding - stream type for later processing. - - Args: - iterable: An asynchronous iterable containing elements of type T. - stream_type: The type of stream this iterable represents. - - Returns: - None - - """ - self._iterable = iterable - self.stream_type = stream_type - - async def __aiter__(self) -> AsyncIterator[T]: - """Asynchronously iterates over the underlying iterable. - - If single message validation is required, wraps the iterable with a validation step. - Otherwise, yields items directly from the iterable. - - Yields: - T: Items from the underlying asynchronous iterable. - - """ - if self.stream_type == StreamType.Unary or self.stream_type == StreamType.ServerStream: - item = await ensure_single(self._iterable) - yield item - else: - async for item in self._iterable: - yield item - - async def ensure_single[T](iterable: AsyncIterable[T], aclose: Callable[[], Awaitable[None]] | None = None) -> T: """Asynchronously ensures that the given async iterable yields exactly one item. @@ -442,7 +392,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and returns an asynchronous content stream. Args: @@ -674,15 +624,24 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S peer information, request headers, and HTTP method. """ - stream = conn.receive(t) - - return StreamRequest( - messages=stream, - spec=conn.spec, - peer=conn.peer, - headers=conn.request_headers, - method=HTTPMethod.POST, - ) + if conn.spec.stream_type == StreamType.ServerStream: + message = await ensure_single(conn.receive(t)) + + return StreamRequest( + messages=AsyncDataStream[T](aiterate([message])), + spec=conn.spec, + peer=conn.peer, + headers=conn.request_headers, + method=HTTPMethod.POST.value, + ) + else: + return StreamRequest( + messages=AsyncDataStream[T](conn.receive(t)), + spec=conn.spec, + peer=conn.peer, + headers=conn.request_headers, + method=HTTPMethod.POST.value, + ) async def recieve_unary_response[T]( @@ -707,7 +666,7 @@ async def recieve_unary_response[T]( Any exceptions raised by `receive_unary_message` or connection errors. """ - message = await receive_unary_message(conn, t, abort_event) + message = await ensure_single(conn.receive(t, abort_event), conn.aclose) return UnaryResponse(message, conn.response_headers, conn.response_trailers) @@ -736,31 +695,13 @@ async def recieve_stream_response[T]( - For other stream types, it directly returns the received stream. """ - receive_stream = AsyncDataStream[T](conn.receive(t, abort_event), conn.aclose) - if spec.stream_type == StreamType.ClientStream: - single_message = await ensure_single(receive_stream, receive_stream.aclose) + single_message = await ensure_single(conn.receive(t, abort_event), conn.aclose) return StreamResponse( AsyncDataStream[T](aiterate([single_message])), conn.response_headers, conn.response_trailers ) else: - return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) - - -async def receive_unary_message[T](conn: StreamingClientConn, t: type[T], abort_event: asyncio.Event | None) -> T: - """Receives exactly one unary message of the specified type from a streaming connection. - - Args: - conn (StreamingClientConn): The streaming client connection to receive the message from. - t (type[T]): The expected type of the message to receive. - abort_event (asyncio.Event | None): An optional event to signal abortion of the receive operation. - - Returns: - T: The single message received of type `t`. - - Raises: - ConnectError: If zero or more than one message is received, or if the receive operation fails. - - """ - return await ensure_single(conn.receive(t, abort_event), conn.aclose) + return StreamResponse( + AsyncDataStream[T](conn.receive(t, abort_event), conn.aclose), conn.response_headers, conn.response_trailers + ) diff --git a/src/connect/handler.py b/src/connect/handler.py index a68849a..cdb51ce 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -336,6 +336,13 @@ async def stream_handle( except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + + if isinstance(e, TimeoutError): + error = ConnectError("the operation timed out", Code.DEADLINE_EXCEEDED) + + if isinstance(e, NotImplementedError): + error = ConnectError("not implemented", Code.UNIMPLEMENTED) + await conn.send_error(error) async def unary_handle( @@ -378,6 +385,7 @@ async def unary_handle( if isinstance(e, NotImplementedError): error = ConnectError("not implemented", Code.UNIMPLEMENTED) + await conn.send_error(error) diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 9c342d5..dc963fa 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -26,7 +26,6 @@ from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name from connect.connect import ( Address, - AsyncContentStream, Peer, Spec, StreamingClientConn, @@ -682,7 +681,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: obj = await self.unmarshaler.unmarshal(message) yield obj - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message, unmarshals it, and returns the resulting object. Args: @@ -692,10 +691,7 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ - return AsyncContentStream( - self._receive_messages(message), - stream_type=self.spec.stream_type, - ) + return self._receive_messages(message) @property def request_headers(self) -> Headers: @@ -1337,7 +1333,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: async for obj, _ in self.unmarshaler.unmarshal(message): yield obj - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and returns an asynchronous content stream. This method processes the incoming message through the receive_message method @@ -1351,10 +1347,7 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: processed message, configured with the specification's stream type. """ - return AsyncContentStream( - iterable=self._receive_messages(message), - stream_type=self.spec.stream_type, - ) + return self._receive_messages(message) @property def request_headers(self) -> Headers: diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 79ff1e2..9c61bee 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -23,7 +23,6 @@ from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name from connect.connect import ( Address, - AsyncContentStream, Peer, Spec, StreamingClientConn, @@ -868,7 +867,7 @@ def peer(self) -> Peer: """ return self._peer - def receive(self, message: Any) -> AsyncContentStream[Any]: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and processes it. Args: @@ -879,7 +878,7 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: this will yield exactly one item. """ - return AsyncContentStream(self.unmarshaler.unmarshal(message), self.spec.stream_type) + return self.unmarshaler.unmarshal(message) @property def request_headers(self) -> Headers: From a9a89895cf8d9dcaa55d78be242d2a3fea4bf595 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sun, 4 May 2025 18:55:46 +0900 Subject: [PATCH 12/14] client_runnler: remove sleep for client request --- conformance/client_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/conformance/client_runner.py b/conformance/client_runner.py index adb18cb..2722937 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -429,8 +429,6 @@ async def read_requests() -> None: """Read requests from standard input and process them asynchronously.""" loop = asyncio.get_event_loop() while req := await loop.run_in_executor(None, read_request): - await asyncio.sleep(0.01) - loop.create_task(run_message(req)) asyncio.run(read_requests()) From 066f4f2a95640bb286006c6e8ae8a69f3a716ed0 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sun, 4 May 2025 19:01:00 +0900 Subject: [PATCH 13/14] client: fix grpc and grpc-web option --- conformance/client_runner.py | 4 ++-- src/connect/client.py | 4 ++-- src/connect/options.py | 9 +++------ 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/conformance/client_runner.py b/conformance/client_runner.py index 2722937..6e1fa50 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -219,9 +219,9 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c try: options = ClientOptions() if msg.protocol == config_pb2.PROTOCOL_GRPC: - options.grpc = True + options.protocol = "grpc" if msg.protocol == config_pb2.PROTOCOL_GRPC_WEB: - options.grpc = True + options.protocol = "grpc-web" if msg.compression == config_pb2.COMPRESSION_GZIP: options.request_compression_name = "gzip" diff --git a/src/connect/client.py b/src/connect/client.py index 5cdb77a..913fc26 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -115,9 +115,9 @@ def __init__(self, raw_url: str, options: ClientOptions): self.url = url self.protocol = ProtocolConnect() - if options.grpc: + if options.protocol == "grpc": self.protocol = ProtocolGRPC(web=False) - if options.grpc_web: + elif options.protocol == "grpc-web": self.protocol = ProtocolGRPC(web=True) self.procedure = proto_path self.codec = ProtoBinaryCodec() diff --git a/src/connect/options.py b/src/connect/options.py index 94f1190..ef7eba6 100644 --- a/src/connect/options.py +++ b/src/connect/options.py @@ -1,6 +1,6 @@ """Options for the UniversalHandler class.""" -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field @@ -84,11 +84,8 @@ class ClientOptions(BaseModel): enable_get: bool = Field(default=False) """A boolean indicating whether to enable GET requests.""" - grpc: bool = Field(default=False) - """A boolean indicating whether to use gRPC.""" - - grpc_web: bool = Field(default=False) - """A boolean indicating whether to use gRPC-Web.""" + protocol: Literal["connect", "grpc", "grpc-web"] = Field(default="connect") + """The protocol to use for the request.""" def merge(self, override_options: "ClientOptions | None" = None) -> "ClientOptions": """Merge this options object with an override options object. From bad55e396d62ca822122853580a6763781f92aa7 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sun, 4 May 2025 19:10:39 +0900 Subject: [PATCH 14/14] connect: fix stream request type --- src/connect/connect.py | 4 ++-- src/connect/protocol_grpc.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index 090e287..b4fd760 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -628,7 +628,7 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S message = await ensure_single(conn.receive(t)) return StreamRequest( - messages=AsyncDataStream[T](aiterate([message])), + messages=aiterate([message]), spec=conn.spec, peer=conn.peer, headers=conn.request_headers, @@ -636,7 +636,7 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S ) else: return StreamRequest( - messages=AsyncDataStream[T](conn.receive(t)), + messages=conn.receive(t), spec=conn.spec, peer=conn.peer, headers=conn.request_headers, diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 9c61bee..a7e02e7 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -49,7 +49,7 @@ from connect.request import Request from connect.session import AsyncClientSession from connect.streaming_response import StreamingResponse -from connect.utils import aiterate, map_httpcore_exceptions +from connect.utils import map_httpcore_exceptions from connect.version import __version__ from connect.writer import ServerResponseWriter @@ -567,7 +567,6 @@ def __init__( self.compressions = compressions self.marshaler = marshaler self.unmarshaler = unmarshaler - self.response_content = None self._response_headers = Headers() self._response_trailers = Headers() self._request_headers = request_headers @@ -601,6 +600,7 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None) -> Asyn if self.unmarshaler.bytes_read == 0 and len(self.response_trailers) == 0: self.response_trailers.update(self._response_headers) del self._response_headers[HEADER_CONTENT_TYPE] + server_error = grpc_error_from_trailer(self.response_trailers) if server_error: server_error.metadata = self.response_headers.copy() @@ -969,7 +969,10 @@ async def send_error(self, error: ConnectError) -> None: await self.writer.write( StreamingResponse( - content=aiterate([b""]), headers=self.response_headers, trailers=self.response_trailers, status_code=200 + content=[], + headers=self.response_headers, + trailers=self.response_trailers, + status_code=200, ) ) @@ -1223,7 +1226,6 @@ def grpc_encode_timeout(timeout: float) -> str: grpc_timeout_max_value = 10**8 _units = dict(sorted(_UNIT_TO_SECONDS.items(), key=lambda item: item[1])) - for unit, size in _units.items(): if timeout < size * grpc_timeout_max_value: value = int(timeout / size)