Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
472 changes: 236 additions & 236 deletions conformance/uv.lock

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions src/connect/protocol_connect/base64_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Base64 utilities for Connect protocol."""

import base64


def decode_base64_with_padding(value: str) -> bytes:
    """Decode a base64 string, restoring any padding that was stripped.

    Args:
        value: Base64 encoded string that may be missing padding

    Returns:
        Decoded bytes

    Raises:
        Exception: If base64 decoding fails
    """
    # Base64 input must be a multiple of 4 chars; restore the missing "=".
    missing = -len(value) % 4
    padded = value + ("=" * missing)
    return base64.b64decode(padded.encode())


def decode_urlsafe_base64_with_padding(value: str) -> bytes:
    """Decode a URL-safe base64 string, restoring any padding that was stripped.

    Args:
        value: URL-safe base64 encoded string that may be missing padding

    Returns:
        Decoded bytes

    Raises:
        Exception: If base64 decoding fails
    """
    # Base64 input must be a multiple of 4 chars; restore the missing "=".
    missing = -len(value) % 4
    return base64.urlsafe_b64decode(value + ("=" * missing))


def encode_base64_without_padding(data: bytes) -> str:
    """Encode bytes as base64 and strip the trailing "=" padding.

    Args:
        data: Bytes to encode

    Returns:
        Base64 encoded string with padding removed
    """
    encoded = base64.b64encode(data).decode()
    return encoded.rstrip("=")
50 changes: 30 additions & 20 deletions src/connect/protocol_connect/connect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,20 +417,25 @@ async def send(
request_task = asyncio.create_task(self.pool.handle_async_request(request=request))
abort_task = asyncio.create_task(abort_event.wait())

done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)
try:
done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)

if abort_task in done:
request_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await request_task
if abort_task in done:
request_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await request_task

raise ConnectError("request aborted", Code.CANCELED)
raise ConnectError("request aborted", Code.CANCELED)

abort_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await abort_task
abort_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await abort_task

response = await request_task
response = await request_task
finally:
for task in [request_task, abort_task]:
if not task.done():
task.cancel()

for hook in self._event_hooks["response"]:
hook(response)
Expand Down Expand Up @@ -777,20 +782,25 @@ async def send(
request_task = asyncio.create_task(self.pool.handle_async_request(request=request))
abort_task = asyncio.create_task(abort_event.wait())

done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)
try:
done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED)

if abort_task in done:
request_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await request_task
if abort_task in done:
request_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await request_task

raise ConnectError("request aborted", Code.CANCELED)
raise ConnectError("request aborted", Code.CANCELED)

abort_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await abort_task
abort_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await abort_task

response = await request_task
response = await request_task
finally:
for task in [request_task, abort_task]:
if not task.done():
task.cancel()

for hook in self._event_hooks["response"]:
hook(response)
Expand Down
6 changes: 3 additions & 3 deletions src/connect/protocol_connect/connect_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Provides a ConnectHander class for handling connection protocols."""
"""Provides a ConnectHandler class for handling connection protocols."""

import base64
import json
from collections.abc import (
AsyncIterable,
Expand Down Expand Up @@ -30,6 +29,7 @@
exclude_protocol_headers,
negotiate_compression,
)
from connect.protocol_connect.base64_utils import decode_urlsafe_base64_with_padding
from connect.protocol_connect.constants import (
CONNECT_HEADER_PROTOCOL_VERSION,
CONNECT_HEADER_TIMEOUT,
Expand Down Expand Up @@ -173,7 +173,7 @@ async def conn(

if query_params.get(CONNECT_UNARY_BASE64_QUERY_PARAMETER) == "1":
message_unquoted = unquote(message)
decoded = base64.urlsafe_b64decode(message_unquoted + "=" * (-len(message_unquoted) % 4))
decoded = decode_urlsafe_base64_with_padding(message_unquoted)
else:
decoded = message.encode()

Expand Down
5 changes: 3 additions & 2 deletions src/connect/protocol_connect/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Constants used in the Connect protocol implementation for Python."""

from sys import version
import sys

from connect.version import __version__

Expand All @@ -24,4 +24,5 @@
CONNECT_UNARY_CONNECT_QUERY_PARAMETER = "connect"
CONNECT_UNARY_CONNECT_QUERY_VALUE = "v" + CONNECT_PROTOCOL_VERSION

DEFAULT_CONNECT_USER_AGENT = f"connect-python/{__version__} (Python/{version})"
_python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
DEFAULT_CONNECT_USER_AGENT = f"connect-python/{__version__} (Python/{_python_version})"
19 changes: 10 additions & 9 deletions src/connect/protocol_connect/end_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,21 @@ def end_stream_from_bytes(data: bytes) -> tuple[ConnectError | None, Headers]:

metadata = Headers()
if "metadata" in obj:
if not isinstance(obj["metadata"], dict) or not all(
isinstance(k, str) and isinstance(v, list) for k, v in obj["metadata"].items()
metadata_obj = obj["metadata"]
if not isinstance(metadata_obj, dict) or not all(
isinstance(k, str) and isinstance(v, list) for k, v in metadata_obj.items()
):
raise ConnectError(
"invalid end stream",
Code.UNKNOWN,
)

for key, values in obj["metadata"].items():
value = ", ".join(values)
metadata[key] = value
for key, values in metadata_obj.items():
metadata[key] = ", ".join(values)

if "error" in obj and obj["error"] is not None:
error = error_from_json(obj["error"], parse_error)
error_obj = obj.get("error")
if error_obj is not None:
error = error_from_json(error_obj, parse_error)
return error, metadata
else:
return None, metadata

return None, metadata
21 changes: 15 additions & 6 deletions src/connect/protocol_connect/error_json.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""Module for serializing and deserializing ConnectError objects to and from JSON."""

import base64
import contextlib
import json
import threading
from typing import Any

import google.protobuf.any_pb2 as any_pb2
from google.protobuf import json_format

from connect.code import Code
from connect.error import DEFAULT_ANY_RESOLVER_PREFIX, ConnectError, ErrorDetail
from connect.protocol_connect.base64_utils import decode_base64_with_padding, encode_base64_without_padding

_string_to_code: dict[str, Code] | None = None
_code_mapping_lock = threading.Lock()


def code_to_string(value: Code) -> str:
Expand Down Expand Up @@ -41,6 +43,8 @@ def code_from_string(value: str) -> Code | None:
If the cache is not initialized, it populates the cache by iterating over all Code enum values
and mapping their string representations to the corresponding Code enum.

This function is thread-safe and ensures the global cache is initialized only once.

Args:
value (str): The string representation of the code.

Expand All @@ -50,10 +54,15 @@ def code_from_string(value: str) -> Code | None:
"""
global _string_to_code

# Double-checked locking pattern for thread safety
if _string_to_code is None:
_string_to_code = {}
for code in Code:
_string_to_code[code_to_string(code)] = code
with _code_mapping_lock:
# Check again after acquiring the lock
if _string_to_code is None:
temp_mapping = {}
for code in Code:
temp_mapping[code_to_string(code)] = code
_string_to_code = temp_mapping

return _string_to_code.get(value)

Expand Down Expand Up @@ -94,7 +103,7 @@ def error_from_json(obj: dict[str, Any], fallback: ConnectError) -> ConnectError

type_name = type_name if "/" in type_name else DEFAULT_ANY_RESOLVER_PREFIX + type_name
try:
decoded = base64.b64decode(value.encode() + b"=" * (4 - len(value) % 4))
decoded = decode_base64_with_padding(value)
except Exception as e:
raise fallback from e

Expand Down Expand Up @@ -132,7 +141,7 @@ def error_to_json(error: ConnectError) -> dict[str, Any]:
for detail in error.details:
wire: dict[str, Any] = {
"type": detail.pb_any.TypeName(),
"value": base64.b64encode(detail.pb_any.value).decode().rstrip("="),
"value": encode_base64_without_padding(detail.pb_any.value),
}

with contextlib.suppress(Exception):
Expand Down
5 changes: 3 additions & 2 deletions src/connect/protocol_connect/marshaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,14 @@ def _build_get_url(self, data: bytes, compressed: bool) -> URL:
CONNECT_UNARY_ENCODING_QUERY_PARAMETER: self.codec.name,
})
if self.stable_codec.is_binary() or compressed:
encoded_data = base64.urlsafe_b64encode(data).decode().rstrip("=")
url = url.update_query({
CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8"),
CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: encoded_data,
CONNECT_UNARY_BASE64_QUERY_PARAMETER: "1",
})
else:
url = url.update_query({
CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: data.decode("utf-8"),
CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: data.decode(),
})

if compressed:
Expand Down
Loading