Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conformance/run-testcase.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Timeouts/HTTPVersion:2/Protocol:PROTOCOL_GRPC/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:true/unary
Duplicate Metadata/HTTPVersion:1/Protocol:PROTOCOL_GRPC_WEB/Codec:CODEC_PROTO/Compression:COMPRESSION_IDENTITY/TLS:false/(grpc server impl)/bidi-stream/half-duplex/error-with-responses
130 changes: 69 additions & 61 deletions src/connect/protocol_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,9 @@ async def conn(

if query_params.get(CONNECT_UNARY_BASE64_QUERY_PARAMETER) == "1":
message_unquoted = unquote(message)
missing_padding = len(message_unquoted) % 4
if missing_padding:
message_unquoted += "=" * (4 - missing_padding)

decoded = base64.urlsafe_b64decode(message_unquoted)
decoded = base64.urlsafe_b64decode(message_unquoted + "=" * (-len(message_unquoted) % 4))
else:
decoded = message.encode("utf-8")
decoded = message.encode()

request_stream = aiterate([decoded])
codec_name = encoding
Expand Down Expand Up @@ -678,8 +674,7 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]:
AsyncIterator[Any]: An async iterator yielding the unmarshaled object.

"""
obj = await self.unmarshaler.unmarshal(message)
yield obj
yield await self.unmarshaler.unmarshal(message)

def receive(self, message: Any) -> AsyncIterator[Any]:
"""Receives a message, unmarshals it, and returns the resulting object.
Expand Down Expand Up @@ -881,13 +876,11 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn:
compressions=self.params.compressions,
request_headers=headers,
marshaler=ConnectUnaryRequestMarshaler(
connect_marshaler=ConnectUnaryMarshaler(
codec=self.params.codec,
compression=get_compresion_from_name(self.params.compression_name, self.params.compressions),
compress_min_bytes=self.params.compress_min_bytes,
send_max_bytes=self.params.send_max_bytes,
headers=headers,
)
codec=self.params.codec,
compression=get_compresion_from_name(self.params.compression_name, self.params.compressions),
compress_min_bytes=self.params.compress_min_bytes,
send_max_bytes=self.params.send_max_bytes,
headers=headers,
),
unmarshaler=ConnectUnaryUnmarshaler(
codec=self.params.codec,
Expand Down Expand Up @@ -923,36 +916,53 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn:
return conn


class ConnectUnaryRequestMarshaler:
"""A class responsible for marshaling unary requests using a provided ConnectUnaryMarshaler.
class ConnectUnaryRequestMarshaler(ConnectUnaryMarshaler):
"""ConnectUnaryRequestMarshaler is responsible for marshaling unary request messages for the Connect protocol, with support for GET requests and stable codecs.

This class extends ConnectUnaryMarshaler to provide additional functionality for handling GET requests,
including marshaling messages using a stable codec, enforcing message size limits, and optionally compressing
messages when necessary. It also manages the construction of GET URLs with appropriate query parameters and
headers for the Connect protocol.

Attributes:
connect_marshaler (ConnectUnaryMarshaler): An instance of ConnectUnaryMarshaler used to marshal messages.
enable_get (bool): Flag indicating whether GET requests are enabled.
stable_codec (StableCodec | None): The codec used for stable marshaling, if available.
url (URL | None): The URL to use for the request.

"""

connect_marshaler: ConnectUnaryMarshaler
enable_get: bool
stable_codec: StableCodec | None
url: URL | None

def __init__(
self,
connect_marshaler: ConnectUnaryMarshaler,
codec: Codec | None,
compression: Compression | None,
compress_min_bytes: int,
send_max_bytes: int,
headers: Headers,
enable_get: bool = False,
stable_codec: StableCodec | None = None,
url: URL | None = None,
) -> None:
"""Initialize the ProtocolConnect instance.
"""Initialize the protocol connection with the specified configuration.

Args:
connect_marshaler (ConnectUnaryMarshaler): The marshaler used for connecting.
enable_get (bool, optional): Flag to enable GET requests. Defaults to False.
stable_codec (StableCodec | None, optional): The codec to use for stable connections. Defaults to None.
url (URL | None, optional): The URL for the connection. Defaults to None.
codec (Codec | None): The codec to use for encoding/decoding messages, or None.
compression (Compression | None): The compression algorithm to use, or None.
compress_min_bytes (int): Minimum number of bytes before compression is applied.
send_max_bytes (int): Maximum number of bytes allowed per send operation.
headers (Headers): Headers to include in each request.
enable_get (bool, optional): Whether to enable GET requests. Defaults to False.
stable_codec (StableCodec | None, optional): An optional stable codec for message encoding/decoding. Defaults to None.
url (URL | None, optional): The URL endpoint for the connection. Defaults to None.

Returns:
None

"""
self.connect_marshaler = connect_marshaler
super().__init__(codec, compression, compress_min_bytes, send_max_bytes, headers)
self.enable_get = enable_get
self.stable_codec = stable_codec
self.url = url
Expand All @@ -965,7 +975,7 @@ def marshal(self, message: Any) -> bytes:
Otherwise, if `enable_get` is True and `stable_codec` is not None, marshals
the message using the `marshal_with_get` method.

If `enable_get` is False, marshals the message using the `connect_marshaler`.
If `enable_get` is False, marshals the message using the `.

Args:
message (Any): The message to be marshaled.
Expand All @@ -978,18 +988,18 @@ def marshal(self, message: Any) -> bytes:

"""
if self.enable_get:
if self.connect_marshaler.codec is None:
if self.codec is None:
raise ConnectError("codec is not set", Code.INTERNAL)

if self.stable_codec is None:
raise ConnectError(
f"codec {self.connect_marshaler.codec.name} doesn't support stable marshal; can't use get",
f"codec {self.codec.name} doesn't support stable marshal; can't use get",
Code.INTERNAL,
)
else:
return self.marshal_with_get(message)

return self.connect_marshaler.marshal(message)
return super().marshal(message)

def marshal_with_get(self, message: Any) -> bytes:
"""Marshals the given message and sends it using a GET request.
Expand All @@ -1016,14 +1026,15 @@ def marshal_with_get(self, message: Any) -> bytes:
limit.

"""
assert self.stable_codec is not None
if self.stable_codec is None:
raise ConnectError("stable_codec is not set", Code.INTERNAL)

data = self.stable_codec.marshal_stable(message)

is_too_big = self.connect_marshaler.send_max_bytes > 0 and len(data) > self.connect_marshaler.send_max_bytes
if is_too_big and not self.connect_marshaler.compression:
is_too_big = self.send_max_bytes > 0 and len(data) > self.send_max_bytes
if is_too_big and not self.compression:
raise ConnectError(
f"message size {len(data)} exceeds sendMaxBytes {self.connect_marshaler.send_max_bytes}: enabling request compression may help",
f"message size {len(data)} exceeds sendMaxBytes {self.send_max_bytes}: enabling request compression may help",
Code.RESOURCE_EXHAUSTED,
)

Expand All @@ -1033,12 +1044,12 @@ def marshal_with_get(self, message: Any) -> bytes:
self._write_with_get(url)
return data

assert self.connect_marshaler.compression
data = self.connect_marshaler.compression.compress(data)
if self.compression:
data = self.compression.compress(data)

if self.connect_marshaler.send_max_bytes > 0 and len(data) > self.connect_marshaler.send_max_bytes:
if self.send_max_bytes > 0 and len(data) > self.send_max_bytes:
raise ConnectError(
f"compressed message size {len(data)} exceeds send_max_bytes {self.connect_marshaler.send_max_bytes}",
f"compressed message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}",
Code.RESOURCE_EXHAUSTED,
)

Expand All @@ -1048,16 +1059,16 @@ def marshal_with_get(self, message: Any) -> bytes:
return data

def _build_get_url(self, data: bytes, compressed: bool) -> URL:
assert self.url is not None
assert self.stable_codec is not None
if self.url is None or self.stable_codec is None:
raise ConnectError("url or stable_codec is not set", Code.INTERNAL)

if self.connect_marshaler.codec is None:
if self.codec is None:
raise ConnectError("codec is not set", Code.INTERNAL)

url = self.url
url = url.update_query({
CONNECT_UNARY_CONNECT_QUERY_PARAMETER: CONNECT_UNARY_CONNECT_QUERY_VALUE,
CONNECT_UNARY_ENCODING_QUERY_PARAMETER: self.connect_marshaler.codec.name,
CONNECT_UNARY_ENCODING_QUERY_PARAMETER: self.codec.name,
})
if self.stable_codec.is_binary() or compressed:
url = url.update_query({
Expand All @@ -1070,22 +1081,22 @@ def _build_get_url(self, data: bytes, compressed: bool) -> URL:
})

if compressed:
if not self.connect_marshaler.compression:
if not self.compression:
raise ConnectError(
"compression must be set for compressed message",
Code.INTERNAL,
)

url = url.update_query({CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER: self.connect_marshaler.compression.name})
url = url.update_query({CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER: self.compression.name})

return url

def _write_with_get(self, url: URL) -> None:
with contextlib.suppress(Exception):
del self.connect_marshaler.headers[CONNECT_HEADER_PROTOCOL_VERSION]
del self.connect_marshaler.headers[HEADER_CONTENT_TYPE]
del self.connect_marshaler.headers[HEADER_CONTENT_ENCODING]
del self.connect_marshaler.headers[HEADER_CONTENT_LENGTH]
del self.headers[CONNECT_HEADER_PROTOCOL_VERSION]
del self.headers[HEADER_CONTENT_TYPE]
del self.headers[HEADER_CONTENT_ENCODING]
del self.headers[HEADER_CONTENT_LENGTH]

self.url = url

Expand Down Expand Up @@ -1646,9 +1657,6 @@ async def send(
- The response stream is unmarshaled and validated after the request is completed.

"""
if abort_event and abort_event.is_set():
raise ConnectError("request aborted", Code.CANCELED)

extensions = {}
if timeout:
extensions["timeout"] = {"read": timeout}
Expand Down Expand Up @@ -1710,7 +1718,10 @@ async def _validate_response(self, response: httpcore.Response) -> None:
response_headers = Headers(response.headers)

if response.status != HTTPStatus.OK:
await response.aread()
try:
await response.aread()
finally:
await response.aclose()

raise ConnectError(
f"HTTP {response.status}",
Expand Down Expand Up @@ -1921,9 +1932,6 @@ async def send(
- Handles cancellation and cleanup if the abort event is triggered during the request.

"""
if abort_event and abort_event.is_set():
raise ConnectError("request aborted", Code.CANCELED)

extensions = {}
if timeout:
extensions["timeout"] = {"read": timeout}
Expand All @@ -1933,7 +1941,8 @@ async def send(
data = self.marshaler.marshal(message)

if self.marshaler.enable_get:
assert self.marshaler.url is not None
if self.marshaler.url is None:
raise ConnectError("url is not set", Code.INTERNAL)

request = httpcore.Request(
method=HTTPMethod.GET,
Expand Down Expand Up @@ -2036,7 +2045,7 @@ async def _validate_response(self, response: httpcore.Response) -> None:
self._response_trailers[key[len(CONNECT_UNARY_TRAILER_PREFIX) :]] = value

validate_error = connect_validate_unary_response_content_type(
self.marshaler.connect_marshaler.codec.name if self.marshaler.connect_marshaler.codec else "",
self.marshaler.codec.name if self.marshaler.codec else "",
response.status,
self._response_headers.get(HEADER_CONTENT_TYPE, ""),
)
Expand Down Expand Up @@ -2459,7 +2468,7 @@ def error_to_json(error: ConnectError) -> dict[str, Any]:
for detail in error.details:
wire: dict[str, Any] = {
"type": detail.pb_any.TypeName(),
"value": base64.b64encode(detail.pb_any.value).decode("utf-8").rstrip("="),
"value": base64.b64encode(detail.pb_any.value).decode().rstrip("="),
}

with contextlib.suppress(Exception):
Expand Down Expand Up @@ -2491,7 +2500,6 @@ def error_to_json_bytes(error: ConnectError) -> bytes:
json_obj = error_to_json(error)
json_str = json.dumps(json_obj)

return json_str.encode("utf-8")
return json_str.encode()
except Exception as e:
message = str(e)
raise ConnectError(f"failed to serialize Connect Error: {message}", Code.INTERNAL) from e
raise ConnectError(f"failed to serialize Connect Error: {e}", Code.INTERNAL) from e
4 changes: 1 addition & 3 deletions src/connect/protocol_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]:
env = self.last
if not env:
raise ConnectError("protocol error: empty envelope")

data = copy(env.data)
env.data = b""

Expand Down Expand Up @@ -731,9 +732,6 @@ async def send(
- Validates the HTTP response.

"""
if abort_event and abort_event.is_set():
raise ConnectError("request aborted", Code.CANCELED)

extensions = {}
if timeout:
extensions["timeout"] = {"read": timeout}
Expand Down
1 change: 0 additions & 1 deletion src/connect/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ async def aclose(self) -> None:
"""
if self.aclose_func:
await self.aclose_func()
return


async def aiterate[T](iterable: typing.Iterable[T]) -> typing.AsyncIterator[T]:
Expand Down
Loading