From d235b33bbdb34644fc4528e979a53b71aae312c8 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Mon, 12 May 2025 21:56:42 +0900 Subject: [PATCH 1/3] protocol_connect: small fix --- src/connect/protocol_connect.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index f7a7c60..4151c85 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -234,13 +234,9 @@ async def conn( if query_params.get(CONNECT_UNARY_BASE64_QUERY_PARAMETER) == "1": message_unquoted = unquote(message) - missing_padding = len(message_unquoted) % 4 - if missing_padding: - message_unquoted += "=" * (4 - missing_padding) - - decoded = base64.urlsafe_b64decode(message_unquoted) + decoded = base64.urlsafe_b64decode(message_unquoted + "=" * (-len(message_unquoted) % 4)) else: - decoded = message.encode("utf-8") + decoded = message.encode() request_stream = aiterate([decoded]) codec_name = encoding @@ -678,8 +674,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ - obj = await self.unmarshaler.unmarshal(message) - yield obj + yield await self.unmarshaler.unmarshal(message) def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message, unmarshals it, and returns the resulting object. @@ -2459,7 +2454,7 @@ def error_to_json(error: ConnectError) -> dict[str, Any]: for detail in error.details: wire: dict[str, Any] = { "type": detail.pb_any.TypeName(), - "value": base64.b64encode(detail.pb_any.value).decode("utf-8").rstrip("="), + "value": base64.b64encode(detail.pb_any.value).decode().rstrip("="), } with contextlib.suppress(Exception): @@ -2491,7 +2486,6 @@ def error_to_json_bytes(error: ConnectError) -> bytes: json_obj = error_to_json(error) json_str = json.dumps(json_obj) - return json_str.encode("utf-8") + return json_str.encode() except Exception as e: - message = str(e) - raise ConnectError(f"failed to serialize Connect Error: {message}", Code.INTERNAL) from e + raise ConnectError(f"failed to serialize Connect Error: {e}", Code.INTERNAL) from e From dc15ed1be0601a84abdd5268a8a3c199332c05e8 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Mon, 12 May 2025 22:12:36 +0900 Subject: [PATCH 2/3] protocol_connect: fix connect marshaler --- src/connect/protocol_connect.py | 99 +++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 4151c85..393c4a2 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -876,13 +876,11 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: compressions=self.params.compressions, request_headers=headers, marshaler=ConnectUnaryRequestMarshaler( - connect_marshaler=ConnectUnaryMarshaler( - codec=self.params.codec, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - headers=headers, - ) + codec=self.params.codec, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + headers=headers, ), unmarshaler=ConnectUnaryUnmarshaler( codec=self.params.codec, @@ -918,36 +916,53 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: return conn -class ConnectUnaryRequestMarshaler: - """A class responsible for marshaling unary requests using a provided ConnectUnaryMarshaler. +class ConnectUnaryRequestMarshaler(ConnectUnaryMarshaler): + """ConnectUnaryRequestMarshaler is responsible for marshaling unary request messages for the Connect protocol, with support for GET requests and stable codecs. + + This class extends ConnectUnaryMarshaler to provide additional functionality for handling GET requests, + including marshaling messages using a stable codec, enforcing message size limits, and optionally compressing + messages when necessary. It also manages the construction of GET URLs with appropriate query parameters and + headers for the Connect protocol. Attributes: - connect_marshaler (ConnectUnaryMarshaler): An instance of ConnectUnaryMarshaler used to marshal messages. + enable_get (bool): Flag indicating whether GET requests are enabled. + stable_codec (StableCodec | None): The codec used for stable marshaling, if available. + url (URL | None): The URL to use for the request. """ - connect_marshaler: ConnectUnaryMarshaler enable_get: bool stable_codec: StableCodec | None url: URL | None def __init__( self, - connect_marshaler: ConnectUnaryMarshaler, + codec: Codec | None, + compression: Compression | None, + compress_min_bytes: int, + send_max_bytes: int, + headers: Headers, enable_get: bool = False, stable_codec: StableCodec | None = None, url: URL | None = None, ) -> None: - """Initialize the ProtocolConnect instance. + """Initialize the protocol connection with the specified configuration. Args: - connect_marshaler (ConnectUnaryMarshaler): The marshaler used for connecting. - enable_get (bool, optional): Flag to enable GET requests. Defaults to False. - stable_codec (StableCodec | None, optional): The codec to use for stable connections. Defaults to None. - url (URL | None, optional): The URL for the connection. Defaults to None. + codec (Codec | None): The codec to use for encoding/decoding messages, or None. + compression (Compression | None): The compression algorithm to use, or None. + compress_min_bytes (int): Minimum number of bytes before compression is applied. + send_max_bytes (int): Maximum number of bytes allowed per send operation. + headers (Headers): Headers to include in each request. + enable_get (bool, optional): Whether to enable GET requests. Defaults to False. + stable_codec (StableCodec | None, optional): An optional stable codec for message encoding/decoding. Defaults to None. + url (URL | None, optional): The URL endpoint for the connection. Defaults to None. + + Returns: + None """ - self.connect_marshaler = connect_marshaler + super().__init__(codec, compression, compress_min_bytes, send_max_bytes, headers) self.enable_get = enable_get self.stable_codec = stable_codec self.url = url @@ -960,7 +975,7 @@ def marshal(self, message: Any) -> bytes: Otherwise, if `enable_get` is True and `stable_codec` is not None, marshals the message using the `marshal_with_get` method. - If `enable_get` is False, marshals the message using the `connect_marshaler`. + If `enable_get` is False, marshals the message using the `. Args: message (Any): The message to be marshaled. @@ -973,18 +988,18 @@ def marshal(self, message: Any) -> bytes: """ if self.enable_get: - if self.connect_marshaler.codec is None: + if self.codec is None: raise ConnectError("codec is not set", Code.INTERNAL) if self.stable_codec is None: raise ConnectError( - f"codec {self.connect_marshaler.codec.name} doesn't support stable marshal; can't use get", + f"codec {self.codec.name} doesn't support stable marshal; can't use get", Code.INTERNAL, ) else: return self.marshal_with_get(message) - return self.connect_marshaler.marshal(message) + return super().marshal(message) def marshal_with_get(self, message: Any) -> bytes: """Marshals the given message and sends it using a GET request. @@ -1011,14 +1026,15 @@ def marshal_with_get(self, message: Any) -> bytes: limit. """ - assert self.stable_codec is not None + if self.stable_codec is None: + raise ConnectError("stable_codec is not set", Code.INTERNAL) data = self.stable_codec.marshal_stable(message) - is_too_big = self.connect_marshaler.send_max_bytes > 0 and len(data) > self.connect_marshaler.send_max_bytes - if is_too_big and not self.connect_marshaler.compression: + is_too_big = self.send_max_bytes > 0 and len(data) > self.send_max_bytes + if is_too_big and not self.compression: raise ConnectError( - f"message size {len(data)} exceeds sendMaxBytes {self.connect_marshaler.send_max_bytes}: enabling request compression may help", + f"message size {len(data)} exceeds sendMaxBytes {self.send_max_bytes}: enabling request compression may help", Code.RESOURCE_EXHAUSTED, ) @@ -1028,12 +1044,12 @@ def marshal_with_get(self, message: Any) -> bytes: self._write_with_get(url) return data - assert self.connect_marshaler.compression - data = self.connect_marshaler.compression.compress(data) + assert self.compression + data = self.compression.compress(data) - if self.connect_marshaler.send_max_bytes > 0 and len(data) > self.connect_marshaler.send_max_bytes: + if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: raise ConnectError( - f"compressed message size {len(data)} exceeds send_max_bytes {self.connect_marshaler.send_max_bytes}", + f"compressed message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}", Code.RESOURCE_EXHAUSTED, ) @@ -1043,16 +1059,16 @@ def marshal_with_get(self, message: Any) -> bytes: return data def _build_get_url(self, data: bytes, compressed: bool) -> URL: - assert self.url is not None - assert self.stable_codec is not None + if self.url is None or self.stable_codec is None: + raise ConnectError("url or stable_codec is not set", Code.INTERNAL) - if self.connect_marshaler.codec is None: + if self.codec is None: raise ConnectError("codec is not set", Code.INTERNAL) url = self.url url = url.update_query({ CONNECT_UNARY_CONNECT_QUERY_PARAMETER: CONNECT_UNARY_CONNECT_QUERY_VALUE, - CONNECT_UNARY_ENCODING_QUERY_PARAMETER: self.connect_marshaler.codec.name, + CONNECT_UNARY_ENCODING_QUERY_PARAMETER: self.codec.name, }) if self.stable_codec.is_binary() or compressed: url = url.update_query({ @@ -1065,22 +1081,22 @@ def _build_get_url(self, data: bytes, compressed: bool) -> URL: }) if compressed: - if not self.connect_marshaler.compression: + if not self.compression: raise ConnectError( "compression must be set for compressed message", Code.INTERNAL, ) - url = url.update_query({CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER: self.connect_marshaler.compression.name}) + url = url.update_query({CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER: self.compression.name}) return url def _write_with_get(self, url: URL) -> None: with contextlib.suppress(Exception): - del self.connect_marshaler.headers[CONNECT_HEADER_PROTOCOL_VERSION] - del self.connect_marshaler.headers[HEADER_CONTENT_TYPE] - del self.connect_marshaler.headers[HEADER_CONTENT_ENCODING] - del self.connect_marshaler.headers[HEADER_CONTENT_LENGTH] + del self.headers[CONNECT_HEADER_PROTOCOL_VERSION] + del self.headers[HEADER_CONTENT_TYPE] + del self.headers[HEADER_CONTENT_ENCODING] + del self.headers[HEADER_CONTENT_LENGTH] self.url = url @@ -1367,7 +1383,6 @@ async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[by bytes: Each marshaled message followed by an end stream message """ - error: ConnectError | None = None try: async for message in self.marshaler.marshal(messages): yield message @@ -2031,7 +2046,7 @@ async def _validate_response(self, response: httpcore.Response) -> None: self._response_trailers[key[len(CONNECT_UNARY_TRAILER_PREFIX) :]] = value validate_error = connect_validate_unary_response_content_type( - self.marshaler.connect_marshaler.codec.name if self.marshaler.connect_marshaler.codec else "", + self.marshaler.codec.name if self.marshaler.codec else "", response.status, self._response_headers.get(HEADER_CONTENT_TYPE, ""), ) From 82d397f488d80f2b08e2ba8ecc6991977b596769 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Tue, 13 May 2025 01:03:38 +0900 Subject: [PATCH 3/3] protocol_connect: fix --- conformance/run-testcase.txt | 2 +- src/connect/protocol_connect.py | 19 +++++++++---------- src/connect/protocol_grpc.py | 4 +--- src/connect/utils.py | 1 - 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index a71ce69..718089b 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -Timeouts/HTTPVersion:2/Protocol:PROTOCOL_GRPC/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:true/unary +Duplicate Metadata/HTTPVersion:1/Protocol:PROTOCOL_GRPC_WEB/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:false/(grpc server impl)/bidi-stream/half-duplex/error-with-responses diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 393c4a2..ddd273d 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1044,8 +1044,8 @@ def marshal_with_get(self, message: Any) -> bytes: self._write_with_get(url) return data - assert self.compression - data = self.compression.compress(data) + if self.compression: + data = self.compression.compress(data) if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: raise ConnectError( @@ -1383,6 +1383,7 @@ async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[by bytes: Each marshaled message followed by an end stream message """ + error: ConnectError | None = None try: async for message in self.marshaler.marshal(messages): yield message @@ -1656,9 +1657,6 @@ async def send( - The response stream is unmarshaled and validated after the request is completed. """ - if abort_event and abort_event.is_set(): - raise ConnectError("request aborted", Code.CANCELED) - extensions = {} if timeout: extensions["timeout"] = {"read": timeout} @@ -1720,7 +1718,10 @@ async def _validate_response(self, response: httpcore.Response) -> None: response_headers = Headers(response.headers) if response.status != HTTPStatus.OK: - await response.aread() + try: + await response.aread() + finally: + await response.aclose() raise ConnectError( f"HTTP {response.status}", @@ -1931,9 +1932,6 @@ async def send( - Handles cancellation and cleanup if the abort event is triggered during the request. """ - if abort_event and abort_event.is_set(): - raise ConnectError("request aborted", Code.CANCELED) - extensions = {} if timeout: extensions["timeout"] = {"read": timeout} @@ -1943,7 +1941,8 @@ async def send( data = self.marshaler.marshal(message) if self.marshaler.enable_get: - assert self.marshaler.url is not None + if self.marshaler.url is None: + raise ConnectError("url is not set", Code.INTERNAL) request = httpcore.Request( method=HTTPMethod.GET, diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 750c659..7c0251c 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -518,6 +518,7 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: env = self.last if not env: raise ConnectError("protocol error: empty envelope") + data = copy(env.data) env.data = b"" @@ -731,9 +732,6 @@ async def send( - Validates the HTTP response. """ - if abort_event and abort_event.is_set(): - raise ConnectError("request aborted", Code.CANCELED) - extensions = {} if timeout: extensions["timeout"] = {"read": timeout} diff --git a/src/connect/utils.py b/src/connect/utils.py index 1bab7b7..3e71cb5 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -297,7 +297,6 @@ async def aclose(self) -> None: """ if self.aclose_func: await self.aclose_func() - return async def aiterate[T](iterable: typing.Iterable[T]) -> typing.AsyncIterator[T]: