From c0683874bf2755dd7a51a6a66722f642314769e4 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Tue, 13 May 2025 17:10:27 +0900 Subject: [PATCH 1/3] protocol_connect: refactor protocol_connect file --- src/connect/client.py | 2 +- src/connect/handler.py | 2 +- src/connect/protocol_connect.py | 2505 ----------------- src/connect/protocol_connect/__init__.py | 0 .../protocol_connect/connect_client.py | 851 ++++++ .../protocol_connect/connect_handler.py | 772 +++++ .../protocol_connect/connect_protocol.py | 59 + src/connect/protocol_connect/constants.py | 25 + src/connect/protocol_connect/content_type.py | 103 + src/connect/protocol_connect/end_stream.py | 75 + src/connect/protocol_connect/error_code.py | 67 + src/connect/protocol_connect/error_json.py | 143 + src/connect/protocol_connect/marshaler.py | 351 +++ src/connect/protocol_connect/unmarshaler.py | 214 ++ 14 files changed, 2662 insertions(+), 2507 deletions(-) delete mode 100644 src/connect/protocol_connect.py create mode 100644 src/connect/protocol_connect/__init__.py create mode 100644 src/connect/protocol_connect/connect_client.py create mode 100644 src/connect/protocol_connect/connect_handler.py create mode 100644 src/connect/protocol_connect/connect_protocol.py create mode 100644 src/connect/protocol_connect/constants.py create mode 100644 src/connect/protocol_connect/content_type.py create mode 100644 src/connect/protocol_connect/end_stream.py create mode 100644 src/connect/protocol_connect/error_code.py create mode 100644 src/connect/protocol_connect/error_json.py create mode 100644 src/connect/protocol_connect/marshaler.py create mode 100644 src/connect/protocol_connect/unmarshaler.py diff --git a/src/connect/client.py b/src/connect/client.py index 913fc26..3fa27e4 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -28,7 +28,7 @@ from connect.interceptor import apply_interceptors from connect.options import ClientOptions from connect.protocol import Protocol, ProtocolClient, 
ProtocolClientParams -from connect.protocol_connect import ProtocolConnect +from connect.protocol_connect.connect_protocol import ProtocolConnect from connect.protocol_grpc import ProtocolGRPC from connect.session import AsyncClientSession from connect.utils import aiterate diff --git a/src/connect/handler.py b/src/connect/handler.py index ee1f9c0..8494ab6 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -37,7 +37,7 @@ sorted_accept_post_value, sorted_allow_method_value, ) -from connect.protocol_connect import ( +from connect.protocol_connect.connect_protocol import ( ProtocolConnect, ) from connect.protocol_grpc import ProtocolGRPC diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py deleted file mode 100644 index ddd273d..0000000 --- a/src/connect/protocol_connect.py +++ /dev/null @@ -1,2505 +0,0 @@ -"""Provides classes and functions for handling protocol connections.""" - -import asyncio -import base64 -import contextlib -import json -from collections.abc import ( - AsyncIterable, - AsyncIterator, - Callable, - Mapping, -) -from http import HTTPMethod, HTTPStatus -from sys import version -from typing import Any -from urllib.parse import unquote - -import google.protobuf.any_pb2 as any_pb2 -import httpcore -from google.protobuf import json_format -from yarl import URL - -from connect.byte_stream import HTTPCoreResponseAsyncByteStream -from connect.code import Code -from connect.codec import Codec, CodecNameType, StableCodec -from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name -from connect.connect import ( - Address, - Peer, - Spec, - StreamingClientConn, - StreamingHandlerConn, - StreamType, - ensure_single, -) -from connect.envelope import EnvelopeFlags, EnvelopeReader, EnvelopeWriter -from connect.error import DEFAULT_ANY_RESOLVER_PREFIX, ConnectError, ErrorDetail -from connect.headers import Headers, include_request_headers -from connect.idempotency_level import 
IdempotencyLevel -from connect.protocol import ( - HEADER_CONTENT_ENCODING, - HEADER_CONTENT_LENGTH, - HEADER_CONTENT_TYPE, - HEADER_USER_AGENT, - PROTOCOL_CONNECT, - Protocol, - ProtocolClient, - ProtocolClientParams, - ProtocolHandler, - ProtocolHandlerParams, - code_from_http_status, - exclude_protocol_headers, - negotiate_compression, -) -from connect.request import Request -from connect.response import Response -from connect.session import AsyncClientSession -from connect.streaming_response import StreamingResponse -from connect.utils import ( - aiterate, - get_acallable_attribute, - map_httpcore_exceptions, -) -from connect.version import __version__ -from connect.writer import ServerResponseWriter - -CONNECT_UNARY_HEADER_COMPRESSION = "Content-Encoding" -CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION = "Accept-Encoding" -CONNECT_UNARY_TRAILER_PREFIX = "Trailer-" -CONNECT_STREAMING_HEADER_COMPRESSION = "Connect-Content-Encoding" -CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION = "Connect-Accept-Encoding" -CONNECT_HEADER_TIMEOUT = "Connect-Timeout-Ms" -CONNECT_HEADER_PROTOCOL_VERSION = "Connect-Protocol-Version" -CONNECT_PROTOCOL_VERSION = "1" - -CONNECT_UNARY_CONTENT_TYPE_PREFIX = "application/" -CONNECT_UNARY_CONTENT_TYPE_JSON = "application/json" -CONNECT_STREAMING_CONTENT_TYPE_PREFIX = "application/connect+" - -CONNECT_UNARY_ENCODING_QUERY_PARAMETER = "encoding" -CONNECT_UNARY_MESSAGE_QUERY_PARAMETER = "message" -CONNECT_UNARY_BASE64_QUERY_PARAMETER = "base64" -CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER = "compression" -CONNECT_UNARY_CONNECT_QUERY_PARAMETER = "connect" -CONNECT_UNARY_CONNECT_QUERY_VALUE = "v" + CONNECT_PROTOCOL_VERSION - -DEFAULT_CONNECT_USER_AGENT = f"connect-python/{__version__} (Python/{version})" - - -def connect_codec_from_content_type(stream_type: StreamType, content_type: str) -> str: - """Extract the codec from the content type based on the stream type. - - Args: - stream_type (StreamType): The type of stream (Unary or Streaming). 
- content_type (str): The content type string from which to extract the codec. - - Returns: - str: The extracted codec from the content type. - - """ - if stream_type == StreamType.Unary: - return content_type[len(CONNECT_UNARY_CONTENT_TYPE_PREFIX) :] - - return content_type[len(CONNECT_STREAMING_CONTENT_TYPE_PREFIX) :] - - -def connect_content_type_from_codec_name(stream_type: StreamType, codec_name: str) -> str: - """Generate the content type string for a given stream type and codec name. - - Args: - stream_type (StreamType): The type of the stream (e.g., Unary or Streaming). - codec_name (str): The name of the codec. - - Returns: - str: The content type string constructed from the stream type and codec name. - - """ - if stream_type == StreamType.Unary: - return CONNECT_UNARY_CONTENT_TYPE_PREFIX + codec_name - - return CONNECT_STREAMING_CONTENT_TYPE_PREFIX + codec_name - - -class ConnectHandler(ProtocolHandler): - """A handler for managing protocol connections. - - Attributes: - params (ProtocolHandlerParams): Parameters for the protocol handler. - __methods (list[HTTPMethod]): List of HTTP methods supported by the handler. - accept (list[str]): List of accepted content types. - - """ - - params: ProtocolHandlerParams - _methods: list[HTTPMethod] - accept: list[str] - - def __init__(self, params: ProtocolHandlerParams, methods: list[HTTPMethod], accept: list[str]) -> None: - """Initialize the ProtocolConnect instance. - - Args: - params (ProtocolHandlerParams): The parameters for the protocol handler. - methods (list[HTTPMethod]): A list of HTTP methods. - accept (list[str]): A list of accepted content types. - - """ - self.params = params - self._methods = methods - self.accept = accept - - @property - def methods(self) -> list[HTTPMethod]: - """Return the list of HTTP methods. - - Returns: - list[HTTPMethod]: A list of HTTP methods. - - """ - return self._methods - - def content_types(self) -> list[str]: - """Handle content types. 
- - This method currently does nothing and serves as a placeholder for future - implementation related to content types. - - """ - return self.accept - - def can_handle_payload(self, request: Request, content_type: str) -> bool: - """Check if the handler can handle the payload.""" - if HTTPMethod(request.method) == HTTPMethod.GET: - codec_name = request.query_params.get(CONNECT_UNARY_ENCODING_QUERY_PARAMETER, "") - content_type = connect_content_type_from_codec_name(self.params.spec.stream_type, codec_name) - - return content_type in self.accept - - async def conn( - self, - request: Request, - response_headers: Headers, - response_trailers: Headers, - writer: ServerResponseWriter, - ) -> StreamingHandlerConn | None: - """Handle a connection request. - - Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be sent in the response. - response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer used to send the response. - is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. - - Returns: - StreamingHandlerConn | None: The connection handler or None if not implemented. - - Raises: - ConnectError: If there is an error in negotiating compression, protocol version, or message encoding. 
- - """ - query_params = request.query_params - - if self.params.spec.stream_type == StreamType.Unary: - if HTTPMethod(request.method) == HTTPMethod.GET: - content_encoding = query_params.get(CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, None) - else: - content_encoding = request.headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) - accept_encoding = request.headers.get(CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, None) - else: - content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) - accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) - - request_compression, response_compression, error = negotiate_compression( - self.params.compressions, content_encoding, accept_encoding - ) - - if error is None: - required = self.params.require_connect_protocol_header and self.params.spec.stream_type == StreamType.Unary - error = connect_check_protocol_version(request, required) - - if HTTPMethod(request.method) == HTTPMethod.GET: - encoding = query_params.get(CONNECT_UNARY_ENCODING_QUERY_PARAMETER, "") - message = query_params.get(CONNECT_UNARY_MESSAGE_QUERY_PARAMETER, "") - if error is None and encoding == "": - error = ConnectError( - f"missing {CONNECT_UNARY_ENCODING_QUERY_PARAMETER} parameter", - Code.INVALID_ARGUMENT, - ) - if error is None and message == "": - error = ConnectError( - f"missing {CONNECT_UNARY_MESSAGE_QUERY_PARAMETER} parameter", - Code.INVALID_ARGUMENT, - ) - - if query_params.get(CONNECT_UNARY_BASE64_QUERY_PARAMETER) == "1": - message_unquoted = unquote(message) - decoded = base64.urlsafe_b64decode(message_unquoted + "=" * (-len(message_unquoted) % 4)) - else: - decoded = message.encode() - - request_stream = aiterate([decoded]) - codec_name = encoding - content_type = connect_content_type_from_codec_name(self.params.spec.stream_type, codec_name) - else: - request_stream = request.stream() - content_type = request.headers.get(HEADER_CONTENT_TYPE, "") - codec_name = 
connect_codec_from_content_type(self.params.spec.stream_type, content_type) - - codec = self.params.codecs.get(codec_name) - if error is None and codec is None: - error = ConnectError( - f"invalid message encoding: {codec_name}", - Code.INVALID_ARGUMENT, - ) - - response_headers[HEADER_CONTENT_TYPE] = content_type - - if self.params.spec.stream_type == StreamType.Unary: - response_headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = ( - f"{', '.join(c.name for c in self.params.compressions)}" - ) - else: - if response_compression and response_compression.name != COMPRESSION_IDENTITY: - response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name - - response_headers[CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION] = ( - f"{', '.join(c.name for c in self.params.compressions)}" - ) - - peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, - protocol=PROTOCOL_CONNECT, - query=request.query_params, - ) - - conn: StreamingHandlerConn - if self.params.spec.stream_type == StreamType.Unary: - conn = ConnectUnaryHandlerConn( - writer=writer, - request=request, - peer=peer, - spec=self.params.spec, - marshaler=ConnectUnaryMarshaler( - codec=codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=response_compression, - headers=response_headers, - ), - unmarshaler=ConnectUnaryUnmarshaler( - stream=request_stream, - codec=codec, - compression=request_compression, - read_max_bytes=self.params.read_max_bytes, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) - - else: - conn = ConnectStreamingHandlerConn( - writer=writer, - request=request, - peer=peer, - spec=self.params.spec, - marshaler=ConnectStreamingMarshaler( - codec=codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - 
compression=response_compression, - ), - unmarshaler=ConnectStreamingUnmarshaler( - stream=request.stream(), - codec=codec, - compression=request_compression, - read_max_bytes=self.params.read_max_bytes, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) - - if error: - await conn.send_error(error) - return None - - return conn - - -class ProtocolConnect(Protocol): - """ProtocolConnect is a class that implements the Protocol interface for handling connection protocols.""" - - def handler(self, params: ProtocolHandlerParams) -> ConnectHandler: - """Handle the creation of a ConnectHandler based on the provided ProtocolHandlerParams. - - Args: - params (ProtocolHandlerParams): The parameters required to create the ConnectHandler. - - Returns: - ConnectHandler: An instance of ConnectHandler configured with the appropriate methods and content types. - - """ - methods = [HTTPMethod.POST] - - if params.spec.stream_type == StreamType.Unary and params.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: - methods.append(HTTPMethod.GET) - - content_types: list[str] = [] - for name in params.codecs.names(): - if params.spec.stream_type == StreamType.Unary: - content_types.append(CONNECT_UNARY_CONTENT_TYPE_PREFIX + name) - continue - - content_types.append(CONNECT_STREAMING_CONTENT_TYPE_PREFIX + name) - - return ConnectHandler(params, methods=methods, accept=content_types) - - def client(self, params: ProtocolClientParams) -> ProtocolClient: - """Create and returns a ConnectClient instance. - - Args: - params (ProtocolClientParams): The parameters required to initialize the client. - - Returns: - ProtocolClient: An instance of ConnectClient. - - """ - return ConnectClient(params) - - -class ConnectUnaryUnmarshaler: - """A class to handle the unmarshaling of data using a specified codec. - - Attributes: - codec (Codec): The codec used for unmarshaling the data. 
- body (bytes): The raw data to be unmarshaled. - read_max_bytes (int): The maximum number of bytes to read. - compression (Compression | None): The compression method to use, if any. - - """ - - codec: Codec | None - read_max_bytes: int - compression: Compression | None - stream: AsyncIterable[bytes] | None - - def __init__( - self, - codec: Codec | None, - read_max_bytes: int, - compression: Compression | None = None, - stream: AsyncIterable[bytes] | None = None, - ) -> None: - """Initialize the ProtocolConnect object. - - Args: - stream (AsyncIterable[bytes] | None): The stream of bytes to be unmarshaled. - codec (Codec): The codec used for encoding/decoding the message. - read_max_bytes (int): The maximum number of bytes to read. - compression (Compression | None): The compression method to use, if any. - - """ - self.codec = codec - self.read_max_bytes = read_max_bytes - self.compression = compression - self.stream = stream - - async def unmarshal(self, message: Any) -> Any: - """Asynchronously unmarshals a given message using the provided unmarshal function and codec. - - Args: - message (Any): The message to be unmarshaled. - - Returns: - Any: The result of the unmarshaling process. - - """ - if self.codec is None: - raise ConnectError("codec is not set", Code.INTERNAL) - - return await self.unmarshal_func(message, self.codec.unmarshal) - - async def unmarshal_func(self, message: Any, func: Callable[[bytes, Any], Any]) -> Any: - """Asynchronously unmarshals a message using the provided function. - - This function reads data from the stream in chunks, checks if the total - bytes read exceed the maximum allowed bytes, and optionally decompresses - the data. It then uses the provided function to unmarshal the data into - the desired format. - - Args: - message (Any): The message to be unmarshaled. - func (Callable[[bytes, Any], Any]): A function that takes the raw bytes - and the message, and returns the unmarshaled object. 
- - Returns: - Any: The unmarshaled object. - - Raises: - ConnectError: If the stream is not set, if the message size exceeds the - maximum allowed bytes, or if there is an error during unmarshaling. - - """ - if self.stream is None: - raise ConnectError("stream is not set", Code.INTERNAL) - - chunks: list[bytes] = [] - bytes_read = 0 - try: - async for chunk in self.stream: - chunk_size = len(chunk) - bytes_read += chunk_size - if self.read_max_bytes > 0 and bytes_read > self.read_max_bytes: - raise ConnectError( - f"message size {bytes_read} is larger than configured max {self.read_max_bytes}", - Code.RESOURCE_EXHAUSTED, - ) - - chunks.append(chunk) - - data = b"".join(chunks) - - if len(data) > 0 and self.compression: - data = self.compression.decompress(data, self.read_max_bytes) - - try: - obj = func(data, message) - except Exception as e: - raise ConnectError( - f"unmarshal message: {str(e)}", - Code.INVALID_ARGUMENT, - ) from e - finally: - await self.aclose() - - return obj - - async def aclose(self) -> None: - """Asynchronously close the stream if it is set. - - This method is intended to be called when the stream is no longer needed - to release any associated resources. - - """ - aclose = get_acallable_attribute(self.stream, "aclose") - if aclose: - await aclose() - - -class ConnectUnaryMarshaler: - """ConnectUnaryMarshaler is responsible for serializing and optionally compressing messages. - - Attributes: - codec (Codec): The codec used for serializing messages. - compression (Compression | None): The compression method used for compressing messages, if any. - compress_min_bytes (int): The minimum size in bytes for a message to be compressed. - send_max_bytes (int): The maximum allowed size in bytes for a message to be sent. - headers (Headers | Headers): The headers to be included in the message. 
- - """ - - codec: Codec | None - compression: Compression | None - compress_min_bytes: int - send_max_bytes: int - headers: Headers - - def __init__( - self, - codec: Codec | None, - compression: Compression | None, - compress_min_bytes: int, - send_max_bytes: int, - headers: Headers, - ) -> None: - """Initialize the protocol connection. - - Args: - codec (Codec): The codec to be used for encoding/decoding. - compression (Compression | None): The compression method to be used, or None if no compression. - compress_min_bytes (int): The minimum number of bytes before compression is applied. - send_max_bytes (int): The maximum number of bytes to send in a single message. - headers (Headers): The headers to be included in the connection. - - Returns: - None - - """ - self.codec = codec - self.compression = compression - self.compress_min_bytes = compress_min_bytes - self.send_max_bytes = send_max_bytes - self.headers = headers - - def marshal(self, message: Any) -> bytes: - """Marshals a message into bytes, optionally compressing it if it exceeds a certain size. - - Args: - message (Any): The message to be marshaled. - - Returns: - bytes: The marshaled (and possibly compressed) message. - - Raises: - ConnectError: If there is an error during marshaling or if the message size exceeds the allowed limit. 
- - """ - if self.codec is None: - raise ConnectError("codec is not set", Code.INTERNAL) - - try: - data = self.codec.marshal(message) - except Exception as e: - raise ConnectError(f"marshal message: {str(e)}", Code.INTERNAL) from e - - if len(data) < self.compress_min_bytes or self.compression is None: - if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: - raise ConnectError( - f"message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}", Code.RESOURCE_EXHAUSTED - ) - - return data - - data = self.compression.compress(data) - - if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: - raise ConnectError( - f"compressed message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}", - Code.RESOURCE_EXHAUSTED, - ) - - self.headers[CONNECT_UNARY_HEADER_COMPRESSION] = self.compression.name - - return data - - -class ConnectUnaryHandlerConn(StreamingHandlerConn): - """ConnectUnaryHandlerConn is a handler connection class for unary RPCs in the Connect protocol. - - Attributes: - request (Request): The incoming request object. - marshaler (ConnectUnaryMarshaler): An instance of ConnectUnaryMarshaler used to marshal messages. - unmarshaler (ConnectUnaryUnmarshaler): An instance of ConnectUnaryUnmarshaler used to unmarshal messages. - headers (Headers): The headers for the response. - - """ - - writer: ServerResponseWriter - request: Request - _peer: Peer - _spec: Spec - marshaler: ConnectUnaryMarshaler - unmarshaler: ConnectUnaryUnmarshaler - _request_headers: Headers - _response_headers: Headers - _response_trailers: Headers - - def __init__( - self, - writer: ServerResponseWriter, - request: Request, - peer: Peer, - spec: Spec, - marshaler: ConnectUnaryMarshaler, - unmarshaler: ConnectUnaryUnmarshaler, - request_headers: Headers, - response_headers: Headers, - response_trailers: Headers | None = None, - ) -> None: - """Initialize the protocol connection. - - Args: - writer (ServerResponseWriter): The writer to send the response. 
- request (Request): The incoming request object. - peer (Peer): The peer information. - spec (Spec): The specification object. - marshaler (ConnectUnaryMarshaler): The marshaler to serialize data. - unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler to deserialize data. - request_headers (Headers): The headers for the request. - response_headers (Headers): The headers for the response. - response_trailers (Headers, optional): The trailers for the response. - - """ - self.writer = writer - self.request = request - self._peer = peer - self._spec = spec - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self._request_headers = request_headers - self._response_headers = response_headers - self._response_trailers = response_trailers if response_trailers is not None else Headers() - - def parse_timeout(self) -> float | None: - """Parse the timeout value.""" - try: - timeout = self.request.headers.get(CONNECT_HEADER_TIMEOUT) - if timeout is None: - return None - - timeout_ms = int(timeout) - except ValueError as e: - raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e - - return timeout_ms / 1000 - - @property - def spec(self) -> Spec: - """Return the specification object. - - Returns: - Spec: The specification object. - - """ - return self._spec - - @property - def peer(self) -> Peer: - """Return the peer associated with this instance. - - :return: The peer associated with this instance. - :rtype: Peer - """ - return self._peer - - async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: - """Receives and unmarshals a message into an object. - - Args: - message (Any): The message to be unmarshaled. - - Returns: - AsyncIterator[Any]: An async iterator yielding the unmarshaled object. - - """ - yield await self.unmarshaler.unmarshal(message) - - def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message, unmarshals it, and returns the resulting object. 
- - Args: - message (Any): The message to be unmarshaled. - - Returns: - AsyncIterator[Any]: An async iterator yielding the unmarshaled object. - - """ - return self._receive_messages(message) - - @property - def request_headers(self) -> Headers: - """Retrieve the headers from the request. - - Returns: - Mapping[str, str]: A dictionary-like object containing the request headers. - - """ - return self._request_headers - - async def send(self, messages: AsyncIterable[Any]) -> None: - """Send message(s) by marshaling them into bytes. - - Args: - messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, - this should be an iterable with a single item. - - Returns: - None - - """ - self.merge_response_trailers() - - message = await ensure_single(messages) - - data = self.marshaler.marshal(message) - await self.writer.write(Response(data, HTTPStatus.OK, self.response_headers)) - - @property - def response_headers(self) -> Headers: - """Retrieve the response headers. - - Returns: - Any: The response headers. - - """ - return self._response_headers - - @property - def response_trailers(self) -> Headers: - """Handle response trailers. - - This method is intended to be overridden in subclasses to provide - specific functionality for processing response trailers. - - Returns: - Any: The processed response trailer data. - - """ - return self._response_trailers - - def get_http_method(self) -> HTTPMethod: - """Retrieve the HTTP method from the request. - - Returns: - HTTPMethod: The HTTP method from the request. - - """ - return HTTPMethod(self.request.method) - - async def send_error(self, error: ConnectError) -> None: - """Send an error response. - - This method updates the response headers with the error metadata, - sets the response trailers, converts the error code to an HTTP status code, - serializes the error to JSON, and writes the response. - - Args: - error (ConnectError): The error to be sent in the response. 
- - Returns: - None - - """ - if not error.wire_error: - self.response_headers.update(exclude_protocol_headers(error.metadata)) - - self.merge_response_trailers() - - status_code = connect_code_to_http(error.code) - self.response_headers[HEADER_CONTENT_TYPE] = CONNECT_UNARY_CONTENT_TYPE_JSON - - body = error_to_json_bytes(error) - - await self.writer.write(Response(content=body, headers=self.response_headers, status_code=status_code)) - - def merge_response_trailers(self) -> None: - """Merge response trailers into the response headers. - - This method iterates through the `_response_trailers` dictionary and adds - each trailer key-value pair to the `_response_headers` dictionary, - prefixing the trailer keys with `CONNECT_UNARY_TRAILER_PREFIX`. - - Returns: - None - - """ - for key, value in self._response_trailers.items(): - self._response_headers[CONNECT_UNARY_TRAILER_PREFIX + key] = value - - -class ConnectClient(ProtocolClient): - """ConnectClient is a client for handling connections using the Connect protocol. - - Attributes: - params (ProtocolClientParams): Parameters for the protocol client. - _peer (Peer): The peer object representing the connection endpoint. - - """ - - params: ProtocolClientParams - _peer: Peer - - def __init__(self, params: ProtocolClientParams) -> None: - """Initialize the ProtocolConnect instance with the given parameters. - - Args: - params (ProtocolClientParams): The parameters required to initialize the ProtocolConnect instance. - - """ - self.params = params - self._peer = Peer( - address=Address(host=params.url.host or "", port=params.url.port or 80), - protocol=PROTOCOL_CONNECT, - query={}, - ) - - @property - def peer(self) -> Peer: - """Return the peer associated with this instance. - - :return: The peer associated with this instance. 
- :rtype: Peer - """ - return self._peer - - def write_request_headers(self, stream_type: StreamType, headers: Headers) -> None: - """Write the necessary request headers to the provided headers dictionary. - - This method ensures that the headers dictionary contains the required headers - for a request, including user agent, protocol version, content type, and - optionally, compression settings. - - Args: - stream_type (StreamType): The type of stream for the request. - headers (Headers): The dictionary of headers to be updated. - - Returns: - None - - """ - if headers.get(HEADER_USER_AGENT, None) is None: - headers[HEADER_USER_AGENT] = DEFAULT_CONNECT_USER_AGENT - - headers[CONNECT_HEADER_PROTOCOL_VERSION] = CONNECT_PROTOCOL_VERSION - headers[HEADER_CONTENT_TYPE] = connect_content_type_from_codec_name(stream_type, self.params.codec.name) - - accept_compression_header = CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION - if stream_type != StreamType.Unary: - headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = COMPRESSION_IDENTITY - accept_compression_header = CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION - if self.params.compression_name and self.params.compression_name != COMPRESSION_IDENTITY: - headers[CONNECT_STREAMING_HEADER_COMPRESSION] = self.params.compression_name - - if self.params.compressions: - headers[accept_compression_header] = ", ".join(c.name for c in self.params.compressions) - - def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Establish a unary client connection with the given specifications and headers. - - Args: - spec (Spec): The specification for the connection. - headers (Headers): The headers to be included in the request. - - Returns: - UnaryClientConn: The established unary client connection. 
- - """ - conn: StreamingClientConn - if spec.stream_type == StreamType.Unary: - conn = ConnectUnaryClientConn( - session=self.params.session, - spec=spec, - peer=self.peer, - url=self.params.url, - compressions=self.params.compressions, - request_headers=headers, - marshaler=ConnectUnaryRequestMarshaler( - codec=self.params.codec, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - headers=headers, - ), - unmarshaler=ConnectUnaryUnmarshaler( - codec=self.params.codec, - read_max_bytes=self.params.read_max_bytes, - ), - ) - if spec.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: - conn.marshaler.enable_get = self.params.enable_get - conn.marshaler.url = self.params.url - if isinstance(self.params.codec, StableCodec): - conn.marshaler.stable_codec = self.params.codec - else: - conn = ConnectStreamingClientConn( - session=self.params.session, - spec=spec, - peer=self.peer, - url=self.params.url, - codec=self.params.codec, - compressions=self.params.compressions, - request_headers=headers, - marshaler=ConnectStreamingMarshaler( - codec=self.params.codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), - ), - unmarshaler=ConnectStreamingUnmarshaler( - codec=self.params.codec, - read_max_bytes=self.params.read_max_bytes, - ), - ) - - return conn - - -class ConnectUnaryRequestMarshaler(ConnectUnaryMarshaler): - """ConnectUnaryRequestMarshaler is responsible for marshaling unary request messages for the Connect protocol, with support for GET requests and stable codecs. 
- - This class extends ConnectUnaryMarshaler to provide additional functionality for handling GET requests, - including marshaling messages using a stable codec, enforcing message size limits, and optionally compressing - messages when necessary. It also manages the construction of GET URLs with appropriate query parameters and - headers for the Connect protocol. - - Attributes: - enable_get (bool): Flag indicating whether GET requests are enabled. - stable_codec (StableCodec | None): The codec used for stable marshaling, if available. - url (URL | None): The URL to use for the request. - - """ - - enable_get: bool - stable_codec: StableCodec | None - url: URL | None - - def __init__( - self, - codec: Codec | None, - compression: Compression | None, - compress_min_bytes: int, - send_max_bytes: int, - headers: Headers, - enable_get: bool = False, - stable_codec: StableCodec | None = None, - url: URL | None = None, - ) -> None: - """Initialize the protocol connection with the specified configuration. - - Args: - codec (Codec | None): The codec to use for encoding/decoding messages, or None. - compression (Compression | None): The compression algorithm to use, or None. - compress_min_bytes (int): Minimum number of bytes before compression is applied. - send_max_bytes (int): Maximum number of bytes allowed per send operation. - headers (Headers): Headers to include in each request. - enable_get (bool, optional): Whether to enable GET requests. Defaults to False. - stable_codec (StableCodec | None, optional): An optional stable codec for message encoding/decoding. Defaults to None. - url (URL | None, optional): The URL endpoint for the connection. Defaults to None. - - Returns: - None - - """ - super().__init__(codec, compression, compress_min_bytes, send_max_bytes, headers) - self.enable_get = enable_get - self.stable_codec = stable_codec - self.url = url - - def marshal(self, message: Any) -> bytes: - """Marshal a message into bytes. 
- - If `enable_get` is True and `stable_codec` is None, raises a `ConnectError` - indicating that the codec does not support stable marshal and cannot use get. - Otherwise, if `enable_get` is True and `stable_codec` is not None, marshals - the message using the `marshal_with_get` method. - - If `enable_get` is False, marshals the message using the `. - - Args: - message (Any): The message to be marshaled. - - Returns: - bytes: The marshaled message in bytes. - - Raises: - ConnectError: If `enable_get` is True and `stable_codec` is None. - - """ - if self.enable_get: - if self.codec is None: - raise ConnectError("codec is not set", Code.INTERNAL) - - if self.stable_codec is None: - raise ConnectError( - f"codec {self.codec.name} doesn't support stable marshal; can't use get", - Code.INTERNAL, - ) - else: - return self.marshal_with_get(message) - - return super().marshal(message) - - def marshal_with_get(self, message: Any) -> bytes: - """Marshals the given message and sends it using a GET request. - - This method first marshals the message using the stable codec. If the marshaled - data exceeds the maximum allowed size (`send_max_bytes`) and compression is not - enabled, it raises a `ConnectError`. If the data size is within the limit, it - builds the GET URL and sends the data. - - If the data size exceeds the limit and compression is enabled, it compresses - the data and checks the size again. If the compressed data still exceeds the - limit, it raises a `ConnectError`. Otherwise, it builds the GET URL with the - compressed data and sends it. - - Args: - message (Any): The message to be marshaled and sent. - - Returns: - bytes: The marshaled (and possibly compressed) data. - - Raises: - ConnectError: If the data size exceeds the maximum allowed size and compression - is not enabled, or if the compressed data size still exceeds the - limit. 
- - """ - if self.stable_codec is None: - raise ConnectError("stable_codec is not set", Code.INTERNAL) - - data = self.stable_codec.marshal_stable(message) - - is_too_big = self.send_max_bytes > 0 and len(data) > self.send_max_bytes - if is_too_big and not self.compression: - raise ConnectError( - f"message size {len(data)} exceeds sendMaxBytes {self.send_max_bytes}: enabling request compression may help", - Code.RESOURCE_EXHAUSTED, - ) - - if not is_too_big: - url = self._build_get_url(data, False) - - self._write_with_get(url) - return data - - if self.compression: - data = self.compression.compress(data) - - if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: - raise ConnectError( - f"compressed message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}", - Code.RESOURCE_EXHAUSTED, - ) - - url = self._build_get_url(data, True) - self._write_with_get(url) - - return data - - def _build_get_url(self, data: bytes, compressed: bool) -> URL: - if self.url is None or self.stable_codec is None: - raise ConnectError("url or stable_codec is not set", Code.INTERNAL) - - if self.codec is None: - raise ConnectError("codec is not set", Code.INTERNAL) - - url = self.url - url = url.update_query({ - CONNECT_UNARY_CONNECT_QUERY_PARAMETER: CONNECT_UNARY_CONNECT_QUERY_VALUE, - CONNECT_UNARY_ENCODING_QUERY_PARAMETER: self.codec.name, - }) - if self.stable_codec.is_binary() or compressed: - url = url.update_query({ - CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8"), - CONNECT_UNARY_BASE64_QUERY_PARAMETER: "1", - }) - else: - url = url.update_query({ - CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: data.decode("utf-8"), - }) - - if compressed: - if not self.compression: - raise ConnectError( - "compression must be set for compressed message", - Code.INTERNAL, - ) - - url = url.update_query({CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER: self.compression.name}) - - return url - - def _write_with_get(self, url: URL) -> None: 
- with contextlib.suppress(Exception): - del self.headers[CONNECT_HEADER_PROTOCOL_VERSION] - del self.headers[HEADER_CONTENT_TYPE] - del self.headers[HEADER_CONTENT_ENCODING] - del self.headers[HEADER_CONTENT_LENGTH] - - self.url = url - - -class ConnectStreamingMarshaler(EnvelopeWriter): - """A class responsible for marshaling messages with optional compression. - - Attributes: - codec (Codec): The codec used for marshaling messages. - compression (Compression | None): The compression method used for compressing messages, if any. - - """ - - codec: Codec | None - compress_min_bytes: int - send_max_bytes: int - compression: Compression | None - - def __init__( - self, codec: Codec | None, compression: Compression | None, compress_min_bytes: int, send_max_bytes: int - ) -> None: - """Initialize the ProtocolConnect instance. - - Args: - codec (Codec): The codec to be used for encoding and decoding. - compression (Compression | None): The compression method to be used, or None if no compression is to be applied. - compress_min_bytes (int): The minimum number of bytes before compression is applied. - send_max_bytes (int): The maximum number of bytes that can be sent in a single message. - - """ - self.codec = codec - self.compress_min_bytes = compress_min_bytes - self.send_max_bytes = send_max_bytes - self.compression = compression - - def marshal_end_stream(self, error: ConnectError | None, response_trailers: Headers) -> bytes: - """Serialize the end-of-stream message with optional error and response trailers into a bytes envelope. - - Args: - error (ConnectError | None): An optional error object to include in the end-of-stream message. - response_trailers (Headers): Headers to include as response trailers. - - Returns: - bytes: The serialized envelope containing the end-of-stream message. 
- - """ - json_obj = end_stream_to_json(error, response_trailers) - json_str = json.dumps(json_obj) - - env = self.write_envelope(json_str.encode(), EnvelopeFlags.end_stream) - - return env.encode() - - -class ConnectStreamingUnmarshaler(EnvelopeReader): - """A class to handle the unmarshaling of streaming data. - - Attributes: - codec (Codec): The codec used for unmarshaling data. - compression (Compression | None): The compression method used, if any. - stream (AsyncIterable[bytes] | None): The asynchronous byte stream to read from. - buffer (bytes): The buffer to store incoming data chunks. - - """ - - _end_stream_error: ConnectError | None - _trailers: Headers - - def __init__( - self, - codec: Codec | None, - read_max_bytes: int, - stream: AsyncIterable[bytes] | None = None, - compression: Compression | None = None, - ) -> None: - """Initialize the protocol connection. - - Args: - codec (Codec): The codec to use for encoding and decoding data. - read_max_bytes (int): The maximum number of bytes to read from the stream. - stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read from. Defaults to None. - compression (Compression | None, optional): The compression method to use. Defaults to None. - - """ - super().__init__(codec, read_max_bytes, stream, compression) - self._end_stream_error = None - self._trailers = Headers() - - async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: - """Asynchronously unmarshals messages from the stream. - - Args: - message (Any): The message type to unmarshal. - - Yields: - Any: The unmarshaled message object. - - Raises: - ConnectError: If the stream is not set, if there is an error in the - unmarshaling process, or if there is a protocol error. 
- - """ - async for obj, end in super().unmarshal(message): - if self.last: - error, trailers = end_stream_from_bytes(self.last.data) - self._end_stream_error = error - self._trailers = trailers - - yield obj, end - - @property - def trailers(self) -> Headers: - """Return the trailers headers. - - Trailers are additional headers sent after the body of the message. - - Returns: - Headers: The trailers headers. - - """ - return self._trailers - - @property - def end_stream_error(self) -> ConnectError | None: - """Return the error that occurred at the end of the stream, if any. - - Returns: - ConnectError | None: The error that occurred at the end of the stream, - or None if no error occurred. - - """ - return self._end_stream_error - - -class ConnectStreamingHandlerConn(StreamingHandlerConn): - """ConnectStreamingHandlerConn is a class that handles streaming connections for the Connect protocol. - - Attributes: - writer (ServerResponseWriter): The writer used to send responses. - request (Request): The incoming request object. - _peer (Peer): The peer associated with this connection. - _spec (Spec): The specification object. - marshaler (ConnectStreamingMarshaler): The marshaler used to serialize messages. - unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler used to deserialize messages. - _request_headers (Headers): The headers from the request. - _response_headers (Headers): The headers for the response. - _response_trailers (Headers): The trailers for the response. 
- - """ - - writer: ServerResponseWriter - request: Request - _peer: Peer - _spec: Spec - marshaler: ConnectStreamingMarshaler - unmarshaler: ConnectStreamingUnmarshaler - _request_headers: Headers - _response_headers: Headers - _response_trailers: Headers - - def __init__( - self, - writer: ServerResponseWriter, - request: Request, - peer: Peer, - spec: Spec, - marshaler: ConnectStreamingMarshaler, - unmarshaler: ConnectStreamingUnmarshaler, - request_headers: Headers, - response_headers: Headers, - response_trailers: Headers | None = None, - ) -> None: - """Initialize the protocol connection. - - Args: - writer (ServerResponseWriter): The writer for server responses. - request (Request): The request object. - peer (Peer): The peer information. - spec (Spec): The specification details. - marshaler (ConnectStreamingMarshaler): The marshaler for streaming. - unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler for streaming. - request_headers (Headers): The headers for the request. - response_headers (Headers): The headers for the response. - response_trailers (Headers, optional): The trailers for the response. Defaults to None. - - """ - self.writer = writer - self.request = request - self._peer = peer - self._spec = spec - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self._request_headers = request_headers - self._response_headers = response_headers - self._response_trailers = response_trailers if response_trailers is not None else Headers() - - def parse_timeout(self) -> float | None: - """Parse the timeout value.""" - try: - timeout = self.request.headers.get(CONNECT_HEADER_TIMEOUT) - if timeout is None: - return None - - timeout_ms = int(timeout) - except ValueError as e: - raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e - - return timeout_ms / 1000 - - @property - def spec(self) -> Spec: - """Return the specification object. - - Returns: - Spec: The specification object. 
- - """ - return self._spec - - @property - def peer(self) -> Peer: - """Return the peer associated with this instance. - - :return: The peer associated with this instance. - :rtype: Peer - """ - return self._peer - - async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: - """Asynchronously receives a message and yields unmarshaled objects. - - This method unmarshals the received message and yields each - unmarshaled object one by one as an asynchronous iterator. - - Args: - message (Any): The message to unmarshal. - - Returns: - AsyncIterator[Any]: An asynchronous iterator yielding unmarshaled objects. - - Yields: - Any: Each unmarshaled object from the message. - - """ - async for obj, _ in self.unmarshaler.unmarshal(message): - yield obj - - def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and returns an asynchronous content stream. - - This method processes the incoming message through the receive_message method - and wraps the result in an AsyncContentStream with the appropriate stream type. - - Args: - message (Any): The message to be processed. - - Returns: - AsyncContentStream[Any]: An asynchronous stream of content based on the - processed message, configured with the specification's stream type. - - """ - return self._receive_messages(message) - - @property - def request_headers(self) -> Headers: - """Retrieve the headers from the request. - - Returns: - Mapping[str, str]: A dictionary-like object containing the request headers. - - """ - return self._request_headers - - async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: - """Create an async iterator that marshals messages with error handling. 
- - Args: - messages (AsyncIterable[Any]): Messages to marshal - - Returns: - AsyncIterator[bytes]: Marshaled bytes with end stream message - - Yields: - bytes: Each marshaled message followed by an end stream message - - """ - error: ConnectError | None = None - try: - async for message in self.marshaler.marshal(messages): - yield message - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - body = self.marshaler.marshal_end_stream(error, self.response_trailers) - yield body - - async def send(self, messages: AsyncIterable[Any]) -> None: - """Send a stream of messages asynchronously. - - This method marshals the provided messages and sends them using the writer. - If an error occurs during the marshaling process, it captures the error, - converts it to a JSON object, and sends it as the final message in the stream. - - Args: - messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. - - Returns: - None - - Raises: - ConnectError: If an error occurs during the marshaling process. - - """ - await self.writer.write( - StreamingResponse( - content=self._send_messages(messages), - headers=self.response_headers, - status_code=200, - ) - ) - - @property - def response_headers(self) -> Headers: - """Retrieve the response headers. - - Returns: - Any: The response headers. - - """ - return self._response_headers - - @property - def response_trailers(self) -> Headers: - """Handle response trailers. - - This method is intended to be overridden in subclasses to provide - specific functionality for processing response trailers. - - Returns: - Any: The processed response trailer data. - - """ - return self._response_trailers - - async def send_error(self, error: ConnectError) -> None: - """Send an error response in the form of a JSON object. - - Args: - error (ConnectError): The error object to be sent. 
- - Returns: - None - - """ - body = self.marshaler.marshal_end_stream(error, self.response_trailers) - - await self.writer.write( - StreamingResponse(content=aiterate([body]), headers=self.response_headers, status_code=200) - ) - - -EventHook = Callable[..., Any] - - -class ConnectStreamingClientConn(StreamingClientConn): - """ConnectStreamingClientConn is a class that manages a streaming client connection using the Connect protocol.""" - - _spec: Spec - _peer: Peer - url: URL - codec: Codec - compressions: list[Compression] - marshaler: ConnectStreamingMarshaler - unmarshaler: ConnectStreamingUnmarshaler - response_content: bytes | None - _response_headers: Headers - _response_trailers: Headers - _request_headers: Headers - - def __init__( - self, - session: AsyncClientSession, - spec: Spec, - peer: Peer, - url: URL, - codec: Codec, - compressions: list[Compression], - request_headers: Headers, - marshaler: ConnectStreamingMarshaler, - unmarshaler: ConnectStreamingUnmarshaler, - event_hooks: None | (Mapping[str, list[EventHook]]) = None, - ) -> None: - """Initialize a new instance of the class. - - Args: - session (AsyncClientSession): The session object for the connection. - spec (Spec): The specification object. - peer (Peer): The peer object. - url (URL): The URL for the connection. - codec (Codec): The codec to be used for encoding and decoding. - compressions (list[Compression]): List of compression methods. - request_headers (Headers): The headers for the request. - marshaler (ConnectStreamingMarshaler): The marshaler for streaming. - unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler for streaming. - event_hooks (None | Mapping[str, list[EventHook]], optional): Event hooks for request and response. Defaults to None. 
- - Returns: - None - - """ - event_hooks = {} if event_hooks is None else event_hooks - - self.session = session - self._spec = spec - self._peer = peer - self.url = url - self.codec = codec - self.compressions = compressions - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self.response_content = None - self._response_headers = Headers() - self._response_trailers = Headers() - self._request_headers = request_headers - self._event_hooks = { - "request": list(event_hooks.get("request", [])), - "response": list(event_hooks.get("response", [])), - } - - @property - def spec(self) -> Spec: - """Return the specification of the protocol. - - Returns: - Spec: The specification object of the protocol. - - """ - return self._spec - - @property - def peer(self) -> Peer: - """Return the peer object associated with this instance. - - :return: The peer object. - :rtype: Peer - """ - return self._peer - - @property - def request_headers(self) -> Headers: - """Retrieve the request headers. - - Returns: - Headers: A dictionary-like object containing the request headers. - - """ - return self._request_headers - - @property - def response_headers(self) -> Headers: - """Return the response headers. - - Returns: - Headers: A dictionary-like object containing the response headers. - - """ - return self._response_headers - - @property - def response_trailers(self) -> Headers: - """Return the response trailers. - - Response trailers are additional headers sent after the response body. - - Returns: - Headers: A dictionary containing the response trailers. - - """ - return self._response_trailers - - def on_request_send(self, fn: EventHook) -> None: - """Register a callback function to be called when a request is sent. - - Args: - fn (EventHook): The callback function to be registered. This function - will be called with the request details when a request - is sent. 
- - """ - self._event_hooks["request"].append(fn) - - async def receive(self, message: Any, abort_event: asyncio.Event | None = None) -> AsyncIterator[Any]: - """Asynchronously receives and processes a message. - - Args: - message (Any): The message to be processed. - abort_event (asyncio.Event | None): Event to signal abortion of the operation. - - Yields: - Any: Objects obtained from unmarshaling the message. - - Raises: - ConnectError: If stream is malformed or aborted. - - """ - end_stream_received = False - - async for obj, end in self.unmarshaler.unmarshal(message): - if abort_event and abort_event.is_set(): - raise ConnectError("receive operation aborted", Code.CANCELED) - - if end: - if end_stream_received: - raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) - - end_stream_received = True - error = self.unmarshaler.end_stream_error - if error: - for key, value in self.response_headers.items(): - error.metadata[key] = value - error.metadata.update(self.unmarshaler.trailers.copy()) - raise error - - for key, value in self.unmarshaler.trailers.items(): - self.response_trailers[key] = value - - continue - - if end_stream_received: - raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) - - yield obj - - if not end_stream_received: - raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) - - async def send( - self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None - ) -> None: - """Send an asynchronous HTTP POST request with the given messages and handle the response. - - Args: - messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. - timeout (float | None): Optional timeout value in seconds for the request. If provided, - it sets the read timeout for the request. - abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request. 
- - Raises: - ConnectError: If the request is aborted or if there is an error during the request. - - Hooks: - - Executes hooks registered in `self._event_hooks["request"]` before sending the request. - - Executes hooks registered in `self._event_hooks["response"]` after receiving the response. - - Notes: - - If `abort_event` is provided and set during the request, the request will be canceled, - and a `ConnectError` with code `Code.CANCELED` will be raised. - - The response stream is unmarshaled and validated after the request is completed. - - """ - extensions = {} - if timeout: - extensions["timeout"] = {"read": timeout} - self._request_headers[CONNECT_HEADER_TIMEOUT] = str(int(timeout * 1000)) - - content_iterator = self.marshaler.marshal(messages) - - request = httpcore.Request( - method=HTTPMethod.POST, - url=httpcore.URL( - scheme=self.url.scheme, - host=self.url.host or "", - port=self.url.port, - target=self.url.raw_path, - ), - headers=list( - include_request_headers( - headers=self._request_headers, url=self.url, content=content_iterator, method=HTTPMethod.POST - ).items() - ), - content=content_iterator, - extensions=extensions, - ) - - for hook in self._event_hooks["request"]: - hook(request) - - with map_httpcore_exceptions(): - if not abort_event: - response = await self.session.pool.handle_async_request(request) - else: - request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) - abort_task = asyncio.create_task(abort_event.wait()) - - done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) - - if abort_task in done: - request_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await request_task - - raise ConnectError("request aborted", Code.CANCELED) - - abort_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await abort_task - - response = await request_task - - for hook in self._event_hooks["response"]: - hook(response) - - assert 
isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) - - await self._validate_response(response) - - async def _validate_response(self, response: httpcore.Response) -> None: - response_headers = Headers(response.headers) - - if response.status != HTTPStatus.OK: - try: - await response.aread() - finally: - await response.aclose() - - raise ConnectError( - f"HTTP {response.status}", - code_from_http_status(response.status), - ) - - response_content_type = response_headers.get(HEADER_CONTENT_TYPE, "") - if not response_content_type.startswith(CONNECT_STREAMING_CONTENT_TYPE_PREFIX): - raise ConnectError( - f"invalid content-type: {response_content_type}; expecting {CONNECT_STREAMING_CONTENT_TYPE_PREFIX}", - Code.UNKNOWN, - ) - - response_codec_name = connect_codec_from_content_type(self.spec.stream_type, response_content_type) - if response_codec_name != self.codec.name: - raise ConnectError( - f"invalid content-type: {response_content_type}; expecting {CONNECT_STREAMING_CONTENT_TYPE_PREFIX + self.codec.name}", - Code.INTERNAL, - ) - - compression = response_headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) - if ( - compression - and compression != COMPRESSION_IDENTITY - and not any(c.name == compression for c in self.compressions) - ): - raise ConnectError( - f"unknown encoding {compression}: accepted encodings are {', '.join(c.name for c in self.compressions)}", - Code.INTERNAL, - ) - - self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) - self._response_headers.update(response_headers) - - async def aclose(self) -> None: - """Asynchronously closes the connection by invoking the `aclose` method of the unmarshaler. - - Returns: - None - - """ - await self.unmarshaler.aclose() - - -class ConnectUnaryClientConn(StreamingClientConn): - """A client connection for unary RPCs using the Connect protocol. 
- - Attributes: - _spec (Spec): The specification for the connection. - _peer (Peer): The peer information. - url (URL): The URL for the connection. - compressions (list[Compression]): List of supported compressions. - marshaler (ConnectUnaryRequestMarshaler): The marshaler for requests. - unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler for responses. - response_content (bytes | None): The content of the response. - _response_headers (Headers): The headers of the response. - _response_trailers (Headers): The trailers of the response. - _request_headers (Headers): The headers of the request. - _event_hooks (dict[str, list[EventHook]]): Event hooks for request and response. - - """ - - session: AsyncClientSession - _spec: Spec - _peer: Peer - url: URL - compressions: list[Compression] - marshaler: ConnectUnaryRequestMarshaler - unmarshaler: ConnectUnaryUnmarshaler - response_content: bytes | None - _response_headers: Headers - _response_trailers: Headers - _request_headers: Headers - - def __init__( - self, - session: AsyncClientSession, - spec: Spec, - peer: Peer, - url: URL, - compressions: list[Compression], - request_headers: Headers, - marshaler: ConnectUnaryRequestMarshaler, - unmarshaler: ConnectUnaryUnmarshaler, - event_hooks: None | (Mapping[str, list[EventHook]]) = None, - ) -> None: - """Initialize the ConnectProtocol instance. - - Args: - session (AsyncClientSession): The session for the connection. - spec (Spec): The specification for the connection. - peer (Peer): The peer information. - url (URL): The URL for the connection. - compressions (list[Compression]): List of compression methods. - request_headers (Headers): The headers for the request. - marshaler (ConnectUnaryRequestMarshaler): The marshaler for the request. - unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler for the response. - event_hooks (None | Mapping[str, list[EventHook]], optional): Event hooks for request and response. Defaults to None. 
- - Returns: - None - - """ - event_hooks = {} if event_hooks is None else event_hooks - - self.session = session - self._spec = spec - self._peer = peer - self.url = url - self.compressions = compressions - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self.response_content = None - self._response_headers = Headers() - self._response_trailers = Headers() - self._request_headers = request_headers - self._event_hooks = { - "request": list(event_hooks.get("request", [])), - "response": list(event_hooks.get("response", [])), - } - - @property - def spec(self) -> Spec: - """Return the specification of the protocol. - - Returns: - Spec: The specification object of the protocol. - - """ - return self._spec - - @property - def peer(self) -> Peer: - """Return the peer object associated with this instance. - - :return: The peer object. - :rtype: Peer - """ - return self._peer - - async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: - """Asynchronously receives and unmarshals a message, yielding the resulting object. - - Args: - message (Any): The message to be unmarshaled. - - Yields: - Any: The unmarshaled object. - - """ - obj = await self.unmarshaler.unmarshal(message) - yield obj - - def receive(self, message: Any, _abort_event: asyncio.Event | None) -> AsyncIterator[Any]: - """Receives a message and returns an asynchronous iterator over the processed message. - - Args: - message (Any): The message to be received and processed. - - Returns: - AsyncIterator[Any]: An asynchronous iterator yielding processed message(s). - - """ - return self._receive_messages(message) - - @property - def request_headers(self) -> Headers: - """Retrieve the request headers. - - Returns: - Headers: A dictionary-like object containing the request headers. - - """ - return self._request_headers - - def on_request_send(self, fn: EventHook) -> None: - """Register a callback function to be called when a request is sent. 
- - Args: - fn (EventHook): The callback function to be registered. This function - will be called with the request details when a request - is sent. - - """ - self._event_hooks["request"].append(fn) - - async def send( - self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None - ) -> None: - """Send a single message asynchronously using either HTTP GET or POST, with support for timeouts and request abortion. - - Args: - messages (AsyncIterable[Any]): An asynchronous iterable yielding the message(s) to send. Only a single message is allowed. - timeout (float | None): Optional timeout in seconds for the request. If provided, sets a read timeout for the request. - abort_event (asyncio.Event | None): Optional asyncio event that, if set, aborts the request. - - Raises: - ConnectError: If the request is aborted before or during execution, or if other connection errors occur. - - Side Effects: - - Modifies request headers for timeout and content length as needed. - - Invokes registered request and response event hooks. - - Sets the unmarshaler's stream to the response stream for further processing. - - Validates the response after receiving it. - - Notes: - - If `marshaler.enable_get` is True, sends the request as HTTP GET; otherwise, uses HTTP POST. - - Handles cancellation and cleanup if the abort event is triggered during the request. 
- - """ - extensions = {} - if timeout: - extensions["timeout"] = {"read": timeout} - self._request_headers[CONNECT_HEADER_TIMEOUT] = str(int(timeout * 1000)) - - message = await ensure_single(messages) - data = self.marshaler.marshal(message) - - if self.marshaler.enable_get: - if self.marshaler.url is None: - raise ConnectError("url is not set", Code.INTERNAL) - - request = httpcore.Request( - method=HTTPMethod.GET, - url=httpcore.URL( - scheme=self.marshaler.url.scheme, - host=self.marshaler.url.host or "", - port=self.marshaler.url.port, - target=self.marshaler.url.raw_path_qs, - ), - headers=list( - include_request_headers( - headers=self._request_headers, url=self.url, content=data, method=HTTPMethod.GET - ).items() - ), - extensions=extensions, - ) - else: - self._request_headers[HEADER_CONTENT_LENGTH] = str(len(data)) - - request = httpcore.Request( - method=HTTPMethod.POST, - url=httpcore.URL( - scheme=self.url.scheme, - host=self.url.host or "", - port=self.url.port, - target=self.url.raw_path, - ), - headers=list( - include_request_headers( - headers=self._request_headers, url=self.url, content=data, method=HTTPMethod.POST - ).items() - ), - content=data, - extensions=extensions, - ) - - for hook in self._event_hooks["request"]: - hook(request) - - with map_httpcore_exceptions(): - if not abort_event: - response = await self.session.pool.handle_async_request(request=request) - else: - request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) - abort_task = asyncio.create_task(abort_event.wait()) - - done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) - - if abort_task in done: - request_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await request_task - - raise ConnectError("request aborted", Code.CANCELED) - - abort_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await abort_task - - response = await request_task - - for hook in 
self._event_hooks["response"]: - hook(response) - - assert isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(response.stream) - - await self._validate_response(response) - - @property - def response_headers(self) -> Headers: - """Return the response headers. - - Returns: - Headers: A dictionary-like object containing the response headers. - - """ - return self._response_headers - - @property - def response_trailers(self) -> Headers: - """Return the response trailers. - - Response trailers are additional headers sent after the response body. - - Returns: - Headers: A dictionary containing the response trailers. - - """ - return self._response_trailers - - async def _validate_response(self, response: httpcore.Response) -> None: - self._response_headers.update(Headers(response.headers)) - - for key, value in self._response_headers.items(): - if not key.startswith(CONNECT_UNARY_TRAILER_PREFIX.lower()): - self._response_headers[key] = value - continue - - self._response_trailers[key[len(CONNECT_UNARY_TRAILER_PREFIX) :]] = value - - validate_error = connect_validate_unary_response_content_type( - self.marshaler.codec.name if self.marshaler.codec else "", - response.status, - self._response_headers.get(HEADER_CONTENT_TYPE, ""), - ) - - compression = self._response_headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) - if ( - compression - and compression != COMPRESSION_IDENTITY - and not any(c.name == compression for c in self.compressions) - ): - raise ConnectError( - f"unknown encoding {compression}: accepted encodings are {', '.join(c.name for c in self.compressions)}", - Code.INTERNAL, - ) - - self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) - - if validate_error: - - def json_ummarshal(data: bytes, _message: Any) -> Any: - return json.loads(data) - - try: - data = await self.unmarshaler.unmarshal_func(None, json_ummarshal) - wire_error = error_from_json(data, validate_error) - 
except ConnectError as e: - raise e - except Exception as e: - raise ConnectError( - f"HTTP {response.status}", - code_from_http_status(response.status), - ) from e - - wire_error.metadata = self._response_headers.copy() - wire_error.metadata.update(self._response_trailers) - raise wire_error - - @property - def event_hooks(self) -> dict[str, list[EventHook]]: - """Return the event hooks. - - This method returns a dictionary where the keys are strings representing - event names, and the values are lists of EventHook objects associated with - those events. - - Returns: - dict[str, list[EventHook]]: A dictionary mapping event names to lists - of EventHook objects. - - """ - return self._event_hooks - - @event_hooks.setter - def event_hooks(self, event_hooks: dict[str, list[EventHook]]) -> None: - self._event_hooks = { - "request": list(event_hooks.get("request", [])), - "response": list(event_hooks.get("response", [])), - } - - async def aclose(self) -> None: - """Asynchronously closes the connection or releases any resources held by the object. - - This method should be called when the object is no longer needed to ensure proper cleanup. - Currently, this implementation does not perform any actions, but it can be overridden in subclasses. - - Returns: - None - - """ - return - - -def connect_validate_unary_response_content_type( - request_codec_name: str, - status_code: int, - response_content_type: str, -) -> ConnectError | None: - """Validate the content type of a unary response based on the HTTP status code and method. - - Args: - request_codec_name (str): The name of the codec used for the request. - http_method (HTTPMethod): The HTTP method used for the request. - status_code (int): The HTTP status code of the response. - response_content_type (str): The content type of the response. - - Raises: - ConnectError: If the status code is not OK and the response content type is not valid. 
- - """ - if status_code != HTTPStatus.OK: - # Error response must be JSON-encoded. - if ( - response_content_type == CONNECT_UNARY_CONTENT_TYPE_PREFIX + CodecNameType.JSON - or response_content_type == CONNECT_UNARY_CONTENT_TYPE_PREFIX + CodecNameType.JSON_CHARSET_UTF8 - ): - return ConnectError( - f"HTTP {status_code}", - code_from_http_status(status_code), - ) - - raise ConnectError( - f"HTTP {status_code}", - code_from_http_status(status_code), - ) - - if not response_content_type.startswith(CONNECT_UNARY_CONTENT_TYPE_PREFIX): - raise ConnectError( - f"invalid content-type: {response_content_type}; expecting {CONNECT_UNARY_CONTENT_TYPE_PREFIX}", - Code.UNKNOWN, - ) - - response_codec_name = connect_codec_from_content_type(StreamType.Unary, response_content_type) - if response_codec_name == request_codec_name: - return None - - if (response_codec_name == CodecNameType.JSON and request_codec_name == CodecNameType.JSON_CHARSET_UTF8) or ( - response_codec_name == CodecNameType.JSON_CHARSET_UTF8 and request_codec_name == CodecNameType.JSON - ): - return None - - raise ConnectError( - f"invalid content-type: {response_content_type}; expecting {CONNECT_UNARY_CONTENT_TYPE_PREFIX}{request_codec_name}", - Code.INTERNAL, - ) - - -def connect_check_protocol_version(request: Request, required: bool) -> ConnectError | None: - """Check the protocol version in the request headers for POST requests. - - Args: - request (Request): The incoming HTTP request. - required (bool): Flag indicating whether the protocol version is required. - - Raises: - ValueError: If the protocol version is required but not present in the headers. - ValueError: If the protocol version is present but unsupported. - ValueError: If the HTTP method is unsupported. 
- - """ - match HTTPMethod(request.method): - case HTTPMethod.GET: - version = request.query_params.get(CONNECT_UNARY_CONNECT_QUERY_PARAMETER) - if required and version is None: - return ConnectError( - f'missing required parameter: set {CONNECT_UNARY_CONNECT_QUERY_PARAMETER} to "{CONNECT_UNARY_CONNECT_QUERY_VALUE}"' - ) - elif version is not None and version != CONNECT_UNARY_CONNECT_QUERY_VALUE: - return ConnectError( - f'{CONNECT_UNARY_CONNECT_QUERY_PARAMETER} must be "{CONNECT_UNARY_CONNECT_QUERY_VALUE}": get "{version}"', - ) - case HTTPMethod.POST: - version = request.headers.get(CONNECT_HEADER_PROTOCOL_VERSION, None) - if required and version is None: - return ConnectError( - f'missing required header: set {CONNECT_HEADER_PROTOCOL_VERSION} to "{CONNECT_PROTOCOL_VERSION}"', - Code.INVALID_ARGUMENT, - ) - elif version is not None and version != CONNECT_PROTOCOL_VERSION: - return ConnectError( - f'{CONNECT_HEADER_PROTOCOL_VERSION} must be "{CONNECT_PROTOCOL_VERSION}": get "{version}"', - Code.INVALID_ARGUMENT, - ) - case _: - return ConnectError(f"unsupported method: {request.method}", Code.INVALID_ARGUMENT) - - return None - - -def connect_code_to_http(code: Code) -> int: - """Convert a given `Code` enumeration to its corresponding HTTP status code. - - Args: - code (Code): The `Code` enumeration value to be converted. - - Returns: - int: The corresponding HTTP status code. 
- - The mapping is as follows: - - Code.CANCELED -> 499 - - Code.UNKNOWN -> 500 - - Code.INVALID_ARGUMENT -> 400 - - Code.DEADLINE_EXCEEDED -> 504 - - Code.NOT_FOUND -> 404 - - Code.ALREADY_EXISTS -> 409 - - Code.PERMISSION_DENIED -> 403 - - Code.RESOURCE_EXHAUSTED -> 429 - - Code.FAILED_PRECONDITION -> 400 - - Code.ABORTED -> 409 - - Code.OUT_OF_RANGE -> 400 - - Code.UNIMPLEMENTED -> 501 - - Code.INTERNAL -> 500 - - Code.UNAVAILABLE -> 503 - - Code.DATA_LOSS -> 500 - - Code.UNAUTHENTICATED -> 401 - - Any other code -> 500 - - """ - match code: - case Code.CANCELED: - return 499 - case Code.UNKNOWN: - return 500 - case Code.INVALID_ARGUMENT: - return 400 - case Code.DEADLINE_EXCEEDED: - return 504 - case Code.NOT_FOUND: - return 404 - case Code.ALREADY_EXISTS: - return 409 - case Code.PERMISSION_DENIED: - return 403 - case Code.RESOURCE_EXHAUSTED: - return 429 - case Code.FAILED_PRECONDITION: - return 400 - case Code.ABORTED: - return 409 - case Code.OUT_OF_RANGE: - return 400 - case Code.UNIMPLEMENTED: - return 501 - case Code.INTERNAL: - return 500 - case Code.UNAVAILABLE: - return 503 - case Code.DATA_LOSS: - return 500 - case Code.UNAUTHENTICATED: - return 401 - case _: - return 500 - - -def code_to_string(value: Code) -> str: - """Convert a Code object to its string representation. - - If the Code object has a 'name' attribute and it is not None, the method returns - the lowercase version of the 'name'. Otherwise, it returns the string representation - of the 'value' attribute. - - Args: - value (Code): The Code object to be converted to a string. - - Returns: - str: The string representation of the Code object. - - """ - if not hasattr(value, "name") or value.name is None: - return str(value.value) - - return value.name.lower() - - -_string_to_code: dict[str, Code] | None = None - - -def code_from_string(value: str) -> Code | None: - """Convert a string representation of a code to its corresponding Code enum value. 
- - This function uses a global dictionary to cache the mapping from string to Code enum values. - If the cache is not initialized, it populates the cache by iterating over all Code enum values - and mapping their string representations to the corresponding Code enum. - - Args: - value (str): The string representation of the code. - - Returns: - Code | None: The corresponding Code enum value if found, otherwise None. - - """ - global _string_to_code - - if _string_to_code is None: - _string_to_code = {} - for code in Code: - _string_to_code[code_to_string(code)] = code - - return _string_to_code.get(value) - - -def error_from_json(obj: dict[str, Any], fallback: ConnectError) -> ConnectError: - """Convert a JSON-serializable dictionary to a ConnectError object. - - Args: - obj (dict[str, Any]): The dictionary representing the error in JSON format. - fallback (ConnectError): A fallback ConnectError object to use in case of missing or invalid fields. - - Returns: - ConnectError: The ConnectError object converted from the dictionary. - - Raises: - ConnectError: If the dictionary is missing required fields or contains invalid values, - a ConnectError is raised with an appropriate error message and code. 
- - """ - code = fallback.code - if "code" in obj: - code = code_from_string(obj["code"]) or code - - message = obj.get("message", "") - details = obj.get("details", []) - - error = ConnectError(message, code, wire_error=True) - - for detail in details: - type_name = detail.get("type", None) - value = detail.get("value", None) - - if type_name is None: - raise fallback - if value is None: - raise fallback - - type_name = type_name if "/" in type_name else DEFAULT_ANY_RESOLVER_PREFIX + type_name - try: - decoded = base64.b64decode(value.encode() + b"=" * (4 - len(value) % 4)) - except Exception as e: - raise fallback from e - - error.details.append( - ErrorDetail(pb_any=any_pb2.Any(type_url=type_name, value=decoded), wire_json=json.dumps(detail)) - ) - - return error - - -def end_stream_from_bytes(data: bytes) -> tuple[ConnectError | None, Headers]: - """Parse a byte stream to extract metadata and error information. - - Args: - data (bytes): The byte stream to be parsed. - - Returns: - tuple[ConnectError | None, Headers]: A tuple containing an optional ConnectError - and a Headers object with the parsed metadata. - - Raises: - ConnectError: If the byte stream is invalid or the metadata format is incorrect. 
- - """ - parse_error = ConnectError("invalid end stream", Code.UNKNOWN) - try: - obj = json.loads(data) - except Exception as e: - raise ConnectError( - "invalid end stream", - Code.UNKNOWN, - ) from e - - metadata = Headers() - if "metadata" in obj: - if not isinstance(obj["metadata"], dict) or not all( - isinstance(k, str) and isinstance(v, list) for k, v in obj["metadata"].items() - ): - raise ConnectError( - "invalid end stream", - Code.UNKNOWN, - ) - - for key, values in obj["metadata"].items(): - value = ", ".join(values) - metadata[key] = value - - if "error" in obj and obj["error"] is not None: - error = error_from_json(obj["error"], parse_error) - return error, metadata - else: - return None, metadata - - -def end_stream_to_json(error: ConnectError | None, trailers: Headers) -> dict[str, Any]: - """Convert the end of a stream to a JSON-serializable dictionary. - - Args: - error (ConnectError | None): An optional error object that may contain metadata. - trailers (Headers): Headers object containing metadata. - - Returns: - dict[str, Any]: A dictionary containing the error and metadata information in JSON-serializable format. - - """ - json_obj = {} - - metadata = Headers(trailers.copy()) - if error: - json_obj["error"] = error_to_json(error) - metadata.update(error.metadata.copy()) - - if len(metadata) > 0: - json_obj["metadata"] = {k: v.split(", ") for k, v in metadata.items()} - - return json_obj - - -def error_to_json(error: ConnectError) -> dict[str, Any]: - """Convert a ConnectError object to a JSON-serializable dictionary. - - Args: - error (ConnectError): The error object to convert. - - Returns: - dict[str, Any]: A dictionary representing the error in JSON format. - - "code" (str): The error code as a string. - - "message" (str, optional): The raw error message, if available. - - "details" (list[dict[str, Any]], optional): A list of dictionaries containing error details, if available. 
- Each detail dictionary contains: - - "type" (str): The type name of the detail. - - "value" (str): The base64-encoded value of the detail. - - "debug" (str, optional): The JSON-encoded debug information, if available. - - """ - obj: dict[str, Any] = {"code": error.code.string()} - - if len(error.raw_message) > 0: - obj["message"] = error.raw_message - - if len(error.details) > 0: - wires = [] - for detail in error.details: - wire: dict[str, Any] = { - "type": detail.pb_any.TypeName(), - "value": base64.b64encode(detail.pb_any.value).decode().rstrip("="), - } - - with contextlib.suppress(Exception): - meg = detail.get_inner() - wire["debug"] = json_format.MessageToDict(meg) - - wires.append(wire) - - obj["details"] = wires - - return obj - - -def error_to_json_bytes(error: ConnectError) -> bytes: - """Serialize a ConnectError object to a JSON-encoded byte string. - - Args: - error (ConnectError): The ConnectError object to serialize. - - Returns: - bytes: The JSON-encoded byte string representation of the error. - - Raises: - ConnectError: If serialization fails, a ConnectError is raised with an - appropriate error message and code. 
- - """ - try: - json_obj = error_to_json(error) - json_str = json.dumps(json_obj) - - return json_str.encode() - except Exception as e: - raise ConnectError(f"failed to serialize Connect Error: {e}", Code.INTERNAL) from e diff --git a/src/connect/protocol_connect/__init__.py b/src/connect/protocol_connect/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/connect/protocol_connect/connect_client.py b/src/connect/protocol_connect/connect_client.py new file mode 100644 index 0000000..57c076a --- /dev/null +++ b/src/connect/protocol_connect/connect_client.py @@ -0,0 +1,851 @@ +"""Provides classes and functions for handling protocol connections.""" + +import asyncio +import contextlib +import json +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Callable, + Mapping, +) +from http import HTTPMethod, HTTPStatus +from typing import Any + +import httpcore +from yarl import URL + +from connect.byte_stream import HTTPCoreResponseAsyncByteStream +from connect.code import Code +from connect.codec import Codec, StableCodec +from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name +from connect.connect import ( + Address, + Peer, + Spec, + StreamingClientConn, + StreamType, + ensure_single, +) +from connect.error import ConnectError +from connect.headers import Headers, include_request_headers +from connect.idempotency_level import IdempotencyLevel +from connect.protocol import ( + HEADER_CONTENT_LENGTH, + HEADER_CONTENT_TYPE, + HEADER_USER_AGENT, + PROTOCOL_CONNECT, + ProtocolClient, + ProtocolClientParams, + code_from_http_status, +) +from connect.protocol_connect.constants import ( + CONNECT_HEADER_PROTOCOL_VERSION, + CONNECT_HEADER_TIMEOUT, + CONNECT_PROTOCOL_VERSION, + CONNECT_STREAMING_CONTENT_TYPE_PREFIX, + CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, + CONNECT_STREAMING_HEADER_COMPRESSION, + CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, + CONNECT_UNARY_HEADER_COMPRESSION, + 
CONNECT_UNARY_TRAILER_PREFIX, + DEFAULT_CONNECT_USER_AGENT, +) +from connect.protocol_connect.content_type import ( + connect_codec_from_content_type, + connect_content_type_from_codec_name, + connect_validate_unary_response_content_type, +) +from connect.protocol_connect.error_json import error_from_json +from connect.protocol_connect.marshaler import ConnectStreamingMarshaler, ConnectUnaryRequestMarshaler +from connect.protocol_connect.unmarshaler import ConnectStreamingUnmarshaler, ConnectUnaryUnmarshaler +from connect.session import AsyncClientSession +from connect.utils import ( + map_httpcore_exceptions, +) + +EventHook = Callable[..., Any] + + +class ConnectClient(ProtocolClient): + """ConnectClient is a client for handling connections using the Connect protocol. + + Attributes: + params (ProtocolClientParams): Parameters for the protocol client. + _peer (Peer): The peer object representing the connection endpoint. + + """ + + params: ProtocolClientParams + _peer: Peer + + def __init__(self, params: ProtocolClientParams) -> None: + """Initialize the ProtocolConnect instance with the given parameters. + + Args: + params (ProtocolClientParams): The parameters required to initialize the ProtocolConnect instance. + + """ + self.params = params + self._peer = Peer( + address=Address(host=params.url.host or "", port=params.url.port or 80), + protocol=PROTOCOL_CONNECT, + query={}, + ) + + @property + def peer(self) -> Peer: + """Return the peer associated with this instance. + + :return: The peer associated with this instance. + :rtype: Peer + """ + return self._peer + + def write_request_headers(self, stream_type: StreamType, headers: Headers) -> None: + """Write the necessary request headers to the provided headers dictionary. + + This method ensures that the headers dictionary contains the required headers + for a request, including user agent, protocol version, content type, and + optionally, compression settings. 
+
+        Args:
+            stream_type (StreamType): The type of stream for the request.
+            headers (Headers): The dictionary of headers to be updated.
+
+        Returns:
+            None
+
+        """
+        if headers.get(HEADER_USER_AGENT, None) is None:
+            headers[HEADER_USER_AGENT] = DEFAULT_CONNECT_USER_AGENT
+
+        headers[CONNECT_HEADER_PROTOCOL_VERSION] = CONNECT_PROTOCOL_VERSION
+        headers[HEADER_CONTENT_TYPE] = connect_content_type_from_codec_name(stream_type, self.params.codec.name)
+
+        accept_compression_header = CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION
+        if stream_type != StreamType.Unary:
+            headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = COMPRESSION_IDENTITY
+            accept_compression_header = CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION
+            if self.params.compression_name and self.params.compression_name != COMPRESSION_IDENTITY:
+                headers[CONNECT_STREAMING_HEADER_COMPRESSION] = self.params.compression_name
+
+        if self.params.compressions:
+            headers[accept_compression_header] = ", ".join(c.name for c in self.params.compressions)
+
+    def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn:
+        """Establish a client connection (unary or streaming) with the given specifications and headers.
+
+        Args:
+            spec (Spec): The specification for the connection.
+            headers (Headers): The headers to be included in the request.
+
+        Returns:
+            StreamingClientConn: The established client connection.
+ + """ + conn: StreamingClientConn + if spec.stream_type == StreamType.Unary: + conn = ConnectUnaryClientConn( + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + compressions=self.params.compressions, + request_headers=headers, + marshaler=ConnectUnaryRequestMarshaler( + codec=self.params.codec, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + headers=headers, + ), + unmarshaler=ConnectUnaryUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + ) + if spec.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: + conn.marshaler.enable_get = self.params.enable_get + conn.marshaler.url = self.params.url + if isinstance(self.params.codec, StableCodec): + conn.marshaler.stable_codec = self.params.codec + else: + conn = ConnectStreamingClientConn( + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + codec=self.params.codec, + compressions=self.params.compressions, + request_headers=headers, + marshaler=ConnectStreamingMarshaler( + codec=self.params.codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + ), + unmarshaler=ConnectStreamingUnmarshaler( + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + ) + + return conn + + +class ConnectUnaryClientConn(StreamingClientConn): + """A client connection for unary RPCs using the Connect protocol. + + Attributes: + _spec (Spec): The specification for the connection. + _peer (Peer): The peer information. + url (URL): The URL for the connection. + compressions (list[Compression]): List of supported compressions. + marshaler (ConnectUnaryRequestMarshaler): The marshaler for requests. 
+        unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler for responses.
+        response_content (bytes | None): The content of the response.
+        _response_headers (Headers): The headers of the response.
+        _response_trailers (Headers): The trailers of the response.
+        _request_headers (Headers): The headers of the request.
+        _event_hooks (dict[str, list[EventHook]]): Event hooks for request and response.
+
+    """
+
+    session: AsyncClientSession
+    _spec: Spec
+    _peer: Peer
+    url: URL
+    compressions: list[Compression]
+    marshaler: ConnectUnaryRequestMarshaler
+    unmarshaler: ConnectUnaryUnmarshaler
+    response_content: bytes | None
+    _response_headers: Headers
+    _response_trailers: Headers
+    _request_headers: Headers
+
+    def __init__(
+        self,
+        session: AsyncClientSession,
+        spec: Spec,
+        peer: Peer,
+        url: URL,
+        compressions: list[Compression],
+        request_headers: Headers,
+        marshaler: ConnectUnaryRequestMarshaler,
+        unmarshaler: ConnectUnaryUnmarshaler,
+        event_hooks: None | (Mapping[str, list[EventHook]]) = None,
+    ) -> None:
+        """Initialize the ConnectUnaryClientConn instance.
+
+        Args:
+            session (AsyncClientSession): The session for the connection.
+            spec (Spec): The specification for the connection.
+            peer (Peer): The peer information.
+            url (URL): The URL for the connection.
+            compressions (list[Compression]): List of compression methods.
+            request_headers (Headers): The headers for the request.
+            marshaler (ConnectUnaryRequestMarshaler): The marshaler for the request.
+            unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler for the response.
+            event_hooks (None | Mapping[str, list[EventHook]], optional): Event hooks for request and response. Defaults to None.
+ + Returns: + None + + """ + event_hooks = {} if event_hooks is None else event_hooks + + self.session = session + self._spec = spec + self._peer = peer + self.url = url + self.compressions = compressions + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self.response_content = None + self._response_headers = Headers() + self._response_trailers = Headers() + self._request_headers = request_headers + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def spec(self) -> Spec: + """Return the specification of the protocol. + + Returns: + Spec: The specification object of the protocol. + + """ + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer object associated with this instance. + + :return: The peer object. + :rtype: Peer + """ + return self._peer + + async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously receives and unmarshals a message, yielding the resulting object. + + Args: + message (Any): The message to be unmarshaled. + + Yields: + Any: The unmarshaled object. + + """ + obj = await self.unmarshaler.unmarshal(message) + yield obj + + def receive(self, message: Any, _abort_event: asyncio.Event | None) -> AsyncIterator[Any]: + """Receives a message and returns an asynchronous iterator over the processed message. + + Args: + message (Any): The message to be received and processed. + + Returns: + AsyncIterator[Any]: An asynchronous iterator yielding processed message(s). + + """ + return self._receive_messages(message) + + @property + def request_headers(self) -> Headers: + """Retrieve the request headers. + + Returns: + Headers: A dictionary-like object containing the request headers. + + """ + return self._request_headers + + def on_request_send(self, fn: EventHook) -> None: + """Register a callback function to be called when a request is sent. 
+ + Args: + fn (EventHook): The callback function to be registered. This function + will be called with the request details when a request + is sent. + + """ + self._event_hooks["request"].append(fn) + + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + """Send a single message asynchronously using either HTTP GET or POST, with support for timeouts and request abortion. + + Args: + messages (AsyncIterable[Any]): An asynchronous iterable yielding the message(s) to send. Only a single message is allowed. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets a read timeout for the request. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, aborts the request. + + Raises: + ConnectError: If the request is aborted before or during execution, or if other connection errors occur. + + Side Effects: + - Modifies request headers for timeout and content length as needed. + - Invokes registered request and response event hooks. + - Sets the unmarshaler's stream to the response stream for further processing. + - Validates the response after receiving it. + + Notes: + - If `marshaler.enable_get` is True, sends the request as HTTP GET; otherwise, uses HTTP POST. + - Handles cancellation and cleanup if the abort event is triggered during the request. 
+ + """ + extensions = {} + if timeout: + extensions["timeout"] = {"read": timeout} + self._request_headers[CONNECT_HEADER_TIMEOUT] = str(int(timeout * 1000)) + + message = await ensure_single(messages) + data = self.marshaler.marshal(message) + + if self.marshaler.enable_get: + if self.marshaler.url is None: + raise ConnectError("url is not set", Code.INTERNAL) + + request = httpcore.Request( + method=HTTPMethod.GET, + url=httpcore.URL( + scheme=self.marshaler.url.scheme, + host=self.marshaler.url.host or "", + port=self.marshaler.url.port, + target=self.marshaler.url.raw_path_qs, + ), + headers=list( + include_request_headers( + headers=self._request_headers, url=self.url, content=data, method=HTTPMethod.GET + ).items() + ), + extensions=extensions, + ) + else: + self._request_headers[HEADER_CONTENT_LENGTH] = str(len(data)) + + request = httpcore.Request( + method=HTTPMethod.POST, + url=httpcore.URL( + scheme=self.url.scheme, + host=self.url.host or "", + port=self.url.port, + target=self.url.raw_path, + ), + headers=list( + include_request_headers( + headers=self._request_headers, url=self.url, content=data, method=HTTPMethod.POST + ).items() + ), + content=data, + extensions=extensions, + ) + + for hook in self._event_hooks["request"]: + hook(request) + + with map_httpcore_exceptions(): + if not abort_event: + response = await self.session.pool.handle_async_request(request=request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task + + raise ConnectError("request aborted", Code.CANCELED) + + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task + + for hook in 
self._event_hooks["response"]: + hook(response) + + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(response.stream) + + await self._validate_response(response) + + @property + def response_headers(self) -> Headers: + """Return the response headers. + + Returns: + Headers: A dictionary-like object containing the response headers. + + """ + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Return the response trailers. + + Response trailers are additional headers sent after the response body. + + Returns: + Headers: A dictionary containing the response trailers. + + """ + return self._response_trailers + + async def _validate_response(self, response: httpcore.Response) -> None: + self._response_headers.update(Headers(response.headers)) + + for key, value in self._response_headers.items(): + if not key.startswith(CONNECT_UNARY_TRAILER_PREFIX.lower()): + self._response_headers[key] = value + continue + + self._response_trailers[key[len(CONNECT_UNARY_TRAILER_PREFIX) :]] = value + + validate_error = connect_validate_unary_response_content_type( + self.marshaler.codec.name if self.marshaler.codec else "", + response.status, + self._response_headers.get(HEADER_CONTENT_TYPE, ""), + ) + + compression = self._response_headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) + if ( + compression + and compression != COMPRESSION_IDENTITY + and not any(c.name == compression for c in self.compressions) + ): + raise ConnectError( + f"unknown encoding {compression}: accepted encodings are {', '.join(c.name for c in self.compressions)}", + Code.INTERNAL, + ) + + self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) + + if validate_error: + + def json_ummarshal(data: bytes, _message: Any) -> Any: + return json.loads(data) + + try: + data = await self.unmarshaler.unmarshal_func(None, json_ummarshal) + wire_error = error_from_json(data, validate_error) + 
except ConnectError as e: + raise e + except Exception as e: + raise ConnectError( + f"HTTP {response.status}", + code_from_http_status(response.status), + ) from e + + wire_error.metadata = self._response_headers.copy() + wire_error.metadata.update(self._response_trailers) + raise wire_error + + @property + def event_hooks(self) -> dict[str, list[EventHook]]: + """Return the event hooks. + + This method returns a dictionary where the keys are strings representing + event names, and the values are lists of EventHook objects associated with + those events. + + Returns: + dict[str, list[EventHook]]: A dictionary mapping event names to lists + of EventHook objects. + + """ + return self._event_hooks + + @event_hooks.setter + def event_hooks(self, event_hooks: dict[str, list[EventHook]]) -> None: + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + async def aclose(self) -> None: + """Asynchronously closes the connection or releases any resources held by the object. + + This method should be called when the object is no longer needed to ensure proper cleanup. + Currently, this implementation does not perform any actions, but it can be overridden in subclasses. 
+ + Returns: + None + + """ + return + + +class ConnectStreamingClientConn(StreamingClientConn): + """ConnectStreamingClientConn is a class that manages a streaming client connection using the Connect protocol.""" + + _spec: Spec + _peer: Peer + url: URL + codec: Codec + compressions: list[Compression] + marshaler: ConnectStreamingMarshaler + unmarshaler: ConnectStreamingUnmarshaler + response_content: bytes | None + _response_headers: Headers + _response_trailers: Headers + _request_headers: Headers + + def __init__( + self, + session: AsyncClientSession, + spec: Spec, + peer: Peer, + url: URL, + codec: Codec, + compressions: list[Compression], + request_headers: Headers, + marshaler: ConnectStreamingMarshaler, + unmarshaler: ConnectStreamingUnmarshaler, + event_hooks: None | (Mapping[str, list[EventHook]]) = None, + ) -> None: + """Initialize a new instance of the class. + + Args: + session (AsyncClientSession): The session object for the connection. + spec (Spec): The specification object. + peer (Peer): The peer object. + url (URL): The URL for the connection. + codec (Codec): The codec to be used for encoding and decoding. + compressions (list[Compression]): List of compression methods. + request_headers (Headers): The headers for the request. + marshaler (ConnectStreamingMarshaler): The marshaler for streaming. + unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler for streaming. + event_hooks (None | Mapping[str, list[EventHook]], optional): Event hooks for request and response. Defaults to None. 
+ + Returns: + None + + """ + event_hooks = {} if event_hooks is None else event_hooks + + self.session = session + self._spec = spec + self._peer = peer + self.url = url + self.codec = codec + self.compressions = compressions + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self.response_content = None + self._response_headers = Headers() + self._response_trailers = Headers() + self._request_headers = request_headers + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def spec(self) -> Spec: + """Return the specification of the protocol. + + Returns: + Spec: The specification object of the protocol. + + """ + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer object associated with this instance. + + :return: The peer object. + :rtype: Peer + """ + return self._peer + + @property + def request_headers(self) -> Headers: + """Retrieve the request headers. + + Returns: + Headers: A dictionary-like object containing the request headers. + + """ + return self._request_headers + + @property + def response_headers(self) -> Headers: + """Return the response headers. + + Returns: + Headers: A dictionary-like object containing the response headers. + + """ + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Return the response trailers. + + Response trailers are additional headers sent after the response body. + + Returns: + Headers: A dictionary containing the response trailers. + + """ + return self._response_trailers + + def on_request_send(self, fn: EventHook) -> None: + """Register a callback function to be called when a request is sent. + + Args: + fn (EventHook): The callback function to be registered. This function + will be called with the request details when a request + is sent. 
+ + """ + self._event_hooks["request"].append(fn) + + async def receive(self, message: Any, abort_event: asyncio.Event | None = None) -> AsyncIterator[Any]: + """Asynchronously receives and processes a message. + + Args: + message (Any): The message to be processed. + abort_event (asyncio.Event | None): Event to signal abortion of the operation. + + Yields: + Any: Objects obtained from unmarshaling the message. + + Raises: + ConnectError: If stream is malformed or aborted. + + """ + end_stream_received = False + + async for obj, end in self.unmarshaler.unmarshal(message): + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + + if end: + if end_stream_received: + raise ConnectError("received extra end stream message", Code.INVALID_ARGUMENT) + + end_stream_received = True + error = self.unmarshaler.end_stream_error + if error: + for key, value in self.response_headers.items(): + error.metadata[key] = value + error.metadata.update(self.unmarshaler.trailers.copy()) + raise error + + for key, value in self.unmarshaler.trailers.items(): + self.response_trailers[key] = value + + continue + + if end_stream_received: + raise ConnectError("received message after end stream", Code.INVALID_ARGUMENT) + + yield obj + + if not end_stream_received: + raise ConnectError("missing end stream message", Code.INVALID_ARGUMENT) + + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + """Send an asynchronous HTTP POST request with the given messages and handle the response. + + Args: + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. + timeout (float | None): Optional timeout value in seconds for the request. If provided, + it sets the read timeout for the request. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request. 
+ + Raises: + ConnectError: If the request is aborted or if there is an error during the request. + + Hooks: + - Executes hooks registered in `self._event_hooks["request"]` before sending the request. + - Executes hooks registered in `self._event_hooks["response"]` after receiving the response. + + Notes: + - If `abort_event` is provided and set during the request, the request will be canceled, + and a `ConnectError` with code `Code.CANCELED` will be raised. + - The response stream is unmarshaled and validated after the request is completed. + + """ + extensions = {} + if timeout: + extensions["timeout"] = {"read": timeout} + self._request_headers[CONNECT_HEADER_TIMEOUT] = str(int(timeout * 1000)) + + content_iterator = self.marshaler.marshal(messages) + + request = httpcore.Request( + method=HTTPMethod.POST, + url=httpcore.URL( + scheme=self.url.scheme, + host=self.url.host or "", + port=self.url.port, + target=self.url.raw_path, + ), + headers=list( + include_request_headers( + headers=self._request_headers, url=self.url, content=content_iterator, method=HTTPMethod.POST + ).items() + ), + content=content_iterator, + extensions=extensions, + ) + + for hook in self._event_hooks["request"]: + hook(request) + + with map_httpcore_exceptions(): + if not abort_event: + response = await self.session.pool.handle_async_request(request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task + + raise ConnectError("request aborted", Code.CANCELED) + + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task + + for hook in self._event_hooks["response"]: + hook(response) + + assert 
isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) + + await self._validate_response(response) + + async def _validate_response(self, response: httpcore.Response) -> None: + response_headers = Headers(response.headers) + + if response.status != HTTPStatus.OK: + try: + await response.aread() + finally: + await response.aclose() + + raise ConnectError( + f"HTTP {response.status}", + code_from_http_status(response.status), + ) + + response_content_type = response_headers.get(HEADER_CONTENT_TYPE, "") + if not response_content_type.startswith(CONNECT_STREAMING_CONTENT_TYPE_PREFIX): + raise ConnectError( + f"invalid content-type: {response_content_type}; expecting {CONNECT_STREAMING_CONTENT_TYPE_PREFIX}", + Code.UNKNOWN, + ) + + response_codec_name = connect_codec_from_content_type(self.spec.stream_type, response_content_type) + if response_codec_name != self.codec.name: + raise ConnectError( + f"invalid content-type: {response_content_type}; expecting {CONNECT_STREAMING_CONTENT_TYPE_PREFIX + self.codec.name}", + Code.INTERNAL, + ) + + compression = response_headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) + if ( + compression + and compression != COMPRESSION_IDENTITY + and not any(c.name == compression for c in self.compressions) + ): + raise ConnectError( + f"unknown encoding {compression}: accepted encodings are {', '.join(c.name for c in self.compressions)}", + Code.INTERNAL, + ) + + self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) + self._response_headers.update(response_headers) + + async def aclose(self) -> None: + """Asynchronously closes the connection by invoking the `aclose` method of the unmarshaler. 
+ + Returns: + None + + """ + await self.unmarshaler.aclose() diff --git a/src/connect/protocol_connect/connect_handler.py b/src/connect/protocol_connect/connect_handler.py new file mode 100644 index 0000000..28bca96 --- /dev/null +++ b/src/connect/protocol_connect/connect_handler.py @@ -0,0 +1,772 @@ +import base64 +import json +from collections.abc import ( + AsyncIterable, + AsyncIterator, +) +from http import HTTPMethod, HTTPStatus +from typing import Any +from urllib.parse import unquote + +from connect.code import Code +from connect.compression import COMPRESSION_IDENTITY +from connect.connect import ( + Address, + Peer, + Spec, + StreamingHandlerConn, + StreamType, + ensure_single, +) +from connect.error import ConnectError +from connect.headers import Headers +from connect.protocol import ( + HEADER_CONTENT_TYPE, + PROTOCOL_CONNECT, + ProtocolHandler, + ProtocolHandlerParams, + exclude_protocol_headers, + negotiate_compression, +) +from connect.protocol_connect.constants import ( + CONNECT_HEADER_PROTOCOL_VERSION, + CONNECT_HEADER_TIMEOUT, + CONNECT_PROTOCOL_VERSION, + CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, + CONNECT_STREAMING_HEADER_COMPRESSION, + CONNECT_UNARY_BASE64_QUERY_PARAMETER, + CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, + CONNECT_UNARY_CONNECT_QUERY_PARAMETER, + CONNECT_UNARY_CONNECT_QUERY_VALUE, + CONNECT_UNARY_CONTENT_TYPE_JSON, + CONNECT_UNARY_ENCODING_QUERY_PARAMETER, + CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, + CONNECT_UNARY_HEADER_COMPRESSION, + CONNECT_UNARY_MESSAGE_QUERY_PARAMETER, + CONNECT_UNARY_TRAILER_PREFIX, +) +from connect.protocol_connect.content_type import connect_codec_from_content_type, connect_content_type_from_codec_name +from connect.protocol_connect.error_code import connect_code_to_http +from connect.protocol_connect.error_json import error_to_json +from connect.protocol_connect.marshaler import ConnectStreamingMarshaler, ConnectUnaryMarshaler +from connect.protocol_connect.unmarshaler import 
ConnectStreamingUnmarshaler, ConnectUnaryUnmarshaler
+from connect.request import Request
+from connect.response import Response
+from connect.streaming_response import StreamingResponse
+from connect.utils import (
+    aiterate,
+)
+from connect.writer import ServerResponseWriter
+
+
+class ConnectHandler(ProtocolHandler):
+    """A handler for managing protocol connections.
+
+    Attributes:
+        params (ProtocolHandlerParams): Parameters for the protocol handler.
+        _methods (list[HTTPMethod]): List of HTTP methods supported by the handler.
+        accept (list[str]): List of accepted content types.
+
+    """
+
+    params: ProtocolHandlerParams
+    _methods: list[HTTPMethod]
+    accept: list[str]
+
+    def __init__(self, params: ProtocolHandlerParams, methods: list[HTTPMethod], accept: list[str]) -> None:
+        """Initialize the ConnectHandler instance.
+
+        Args:
+            params (ProtocolHandlerParams): The parameters for the protocol handler.
+            methods (list[HTTPMethod]): A list of HTTP methods.
+            accept (list[str]): A list of accepted content types.
+
+        """
+        self.params = params
+        self._methods = methods
+        self.accept = accept
+
+    @property
+    def methods(self) -> list[HTTPMethod]:
+        """Return the list of HTTP methods.
+
+        Returns:
+            list[HTTPMethod]: A list of HTTP methods.
+
+        """
+        return self._methods
+
+    def content_types(self) -> list[str]:
+        """Return the accepted content types.
+
+        Returns:
+            list[str]: The content types this handler can accept.
+
+        """
+        return self.accept
+
+    def can_handle_payload(self, request: Request, content_type: str) -> bool:
+        """Check if the handler can handle the payload."""
+        if HTTPMethod(request.method) == HTTPMethod.GET:
+            codec_name = request.query_params.get(CONNECT_UNARY_ENCODING_QUERY_PARAMETER, "")
+            content_type = connect_content_type_from_codec_name(self.params.spec.stream_type, codec_name)
+
+        return content_type in self.accept
+
+    async def conn(
+        self,
+        request: Request,
+        response_headers: Headers,
+        response_trailers: Headers,
+        writer: ServerResponseWriter,
+    ) -> StreamingHandlerConn | None:
+        """Handle a connection request.
+
+        Args:
+            request (Request): The incoming request object.
+            response_headers (Headers): The headers to be sent in the response.
+            response_trailers (Headers): The trailers to be sent in the response.
+            writer (ServerResponseWriter): The writer used to send the response.
+                Error responses are also written through this writer.
+
+        Returns:
+            StreamingHandlerConn | None: The connection handler, or None if an error response was already sent.
+
+        Raises:
+            ConnectError: If there is an error in negotiating compression, protocol version, or message encoding.
+ + """ + query_params = request.query_params + + if self.params.spec.stream_type == StreamType.Unary: + if HTTPMethod(request.method) == HTTPMethod.GET: + content_encoding = query_params.get(CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, None) + else: + content_encoding = request.headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, None) + else: + content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) + + request_compression, response_compression, error = negotiate_compression( + self.params.compressions, content_encoding, accept_encoding + ) + + if error is None: + required = self.params.require_connect_protocol_header and self.params.spec.stream_type == StreamType.Unary + error = connect_check_protocol_version(request, required) + + if HTTPMethod(request.method) == HTTPMethod.GET: + encoding = query_params.get(CONNECT_UNARY_ENCODING_QUERY_PARAMETER, "") + message = query_params.get(CONNECT_UNARY_MESSAGE_QUERY_PARAMETER, "") + if error is None and encoding == "": + error = ConnectError( + f"missing {CONNECT_UNARY_ENCODING_QUERY_PARAMETER} parameter", + Code.INVALID_ARGUMENT, + ) + if error is None and message == "": + error = ConnectError( + f"missing {CONNECT_UNARY_MESSAGE_QUERY_PARAMETER} parameter", + Code.INVALID_ARGUMENT, + ) + + if query_params.get(CONNECT_UNARY_BASE64_QUERY_PARAMETER) == "1": + message_unquoted = unquote(message) + decoded = base64.urlsafe_b64decode(message_unquoted + "=" * (-len(message_unquoted) % 4)) + else: + decoded = message.encode() + + request_stream = aiterate([decoded]) + codec_name = encoding + content_type = connect_content_type_from_codec_name(self.params.spec.stream_type, codec_name) + else: + request_stream = request.stream() + content_type = request.headers.get(HEADER_CONTENT_TYPE, "") + codec_name = 
connect_codec_from_content_type(self.params.spec.stream_type, content_type) + + codec = self.params.codecs.get(codec_name) + if error is None and codec is None: + error = ConnectError( + f"invalid message encoding: {codec_name}", + Code.INVALID_ARGUMENT, + ) + + response_headers[HEADER_CONTENT_TYPE] = content_type + + if self.params.spec.stream_type == StreamType.Unary: + response_headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = ( + f"{', '.join(c.name for c in self.params.compressions)}" + ) + else: + if response_compression and response_compression.name != COMPRESSION_IDENTITY: + response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name + + response_headers[CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION] = ( + f"{', '.join(c.name for c in self.params.compressions)}" + ) + + peer = Peer( + address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, + protocol=PROTOCOL_CONNECT, + query=request.query_params, + ) + + conn: StreamingHandlerConn + if self.params.spec.stream_type == StreamType.Unary: + conn = ConnectUnaryHandlerConn( + writer=writer, + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectUnaryMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=response_compression, + headers=response_headers, + ), + unmarshaler=ConnectUnaryUnmarshaler( + stream=request_stream, + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) + + else: + conn = ConnectStreamingHandlerConn( + writer=writer, + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectStreamingMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + 
compression=response_compression, + ), + unmarshaler=ConnectStreamingUnmarshaler( + stream=request.stream(), + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) + + if error: + await conn.send_error(error) + return None + + return conn + + +class ConnectUnaryHandlerConn(StreamingHandlerConn): + """ConnectUnaryHandlerConn is a handler connection class for unary RPCs in the Connect protocol. + + Attributes: + request (Request): The incoming request object. + marshaler (ConnectUnaryMarshaler): An instance of ConnectUnaryMarshaler used to marshal messages. + unmarshaler (ConnectUnaryUnmarshaler): An instance of ConnectUnaryUnmarshaler used to unmarshal messages. + headers (Headers): The headers for the response. + + """ + + writer: ServerResponseWriter + request: Request + _peer: Peer + _spec: Spec + marshaler: ConnectUnaryMarshaler + unmarshaler: ConnectUnaryUnmarshaler + _request_headers: Headers + _response_headers: Headers + _response_trailers: Headers + + def __init__( + self, + writer: ServerResponseWriter, + request: Request, + peer: Peer, + spec: Spec, + marshaler: ConnectUnaryMarshaler, + unmarshaler: ConnectUnaryUnmarshaler, + request_headers: Headers, + response_headers: Headers, + response_trailers: Headers | None = None, + ) -> None: + """Initialize the protocol connection. + + Args: + writer (ServerResponseWriter): The writer to send the response. + request (Request): The incoming request object. + peer (Peer): The peer information. + spec (Spec): The specification object. + marshaler (ConnectUnaryMarshaler): The marshaler to serialize data. + unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler to deserialize data. + request_headers (Headers): The headers for the request. + response_headers (Headers): The headers for the response. 
+ response_trailers (Headers, optional): The trailers for the response. + + """ + self.writer = writer + self.request = request + self._peer = peer + self._spec = spec + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self._request_headers = request_headers + self._response_headers = response_headers + self._response_trailers = response_trailers if response_trailers is not None else Headers() + + def parse_timeout(self) -> float | None: + """Parse the timeout value.""" + try: + timeout = self.request.headers.get(CONNECT_HEADER_TIMEOUT) + if timeout is None: + return None + + timeout_ms = int(timeout) + except ValueError as e: + raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e + + return timeout_ms / 1000 + + @property + def spec(self) -> Spec: + """Return the specification object. + + Returns: + Spec: The specification object. + + """ + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer associated with this instance. + + :return: The peer associated with this instance. + :rtype: Peer + """ + return self._peer + + async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: + """Receives and unmarshals a message into an object. + + Args: + message (Any): The message to be unmarshaled. + + Returns: + AsyncIterator[Any]: An async iterator yielding the unmarshaled object. + + """ + yield await self.unmarshaler.unmarshal(message) + + def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message, unmarshals it, and returns the resulting object. + + Args: + message (Any): The message to be unmarshaled. + + Returns: + AsyncIterator[Any]: An async iterator yielding the unmarshaled object. + + """ + return self._receive_messages(message) + + @property + def request_headers(self) -> Headers: + """Retrieve the headers from the request. + + Returns: + Mapping[str, str]: A dictionary-like object containing the request headers. 
+ + """ + return self._request_headers + + async def send(self, messages: AsyncIterable[Any]) -> None: + """Send message(s) by marshaling them into bytes. + + Args: + messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, + this should be an iterable with a single item. + + Returns: + None + + """ + self.merge_response_trailers() + + message = await ensure_single(messages) + + data = self.marshaler.marshal(message) + await self.writer.write(Response(data, HTTPStatus.OK, self.response_headers)) + + @property + def response_headers(self) -> Headers: + """Retrieve the response headers. + + Returns: + Any: The response headers. + + """ + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Handle response trailers. + + This method is intended to be overridden in subclasses to provide + specific functionality for processing response trailers. + + Returns: + Any: The processed response trailer data. + + """ + return self._response_trailers + + def get_http_method(self) -> HTTPMethod: + """Retrieve the HTTP method from the request. + + Returns: + HTTPMethod: The HTTP method from the request. + + """ + return HTTPMethod(self.request.method) + + async def send_error(self, error: ConnectError) -> None: + """Send an error response. + + This method updates the response headers with the error metadata, + sets the response trailers, converts the error code to an HTTP status code, + serializes the error to JSON, and writes the response. + + Args: + error (ConnectError): The error to be sent in the response. 
+ + Returns: + None + + """ + if not error.wire_error: + self.response_headers.update(exclude_protocol_headers(error.metadata)) + + self.merge_response_trailers() + + status_code = connect_code_to_http(error.code) + self.response_headers[HEADER_CONTENT_TYPE] = CONNECT_UNARY_CONTENT_TYPE_JSON + + body = error_to_json_bytes(error) + + await self.writer.write(Response(content=body, headers=self.response_headers, status_code=status_code)) + + def merge_response_trailers(self) -> None: + """Merge response trailers into the response headers. + + This method iterates through the `_response_trailers` dictionary and adds + each trailer key-value pair to the `_response_headers` dictionary, + prefixing the trailer keys with `CONNECT_UNARY_TRAILER_PREFIX`. + + Returns: + None + + """ + for key, value in self._response_trailers.items(): + self._response_headers[CONNECT_UNARY_TRAILER_PREFIX + key] = value + + +class ConnectStreamingHandlerConn(StreamingHandlerConn): + """ConnectStreamingHandlerConn is a class that handles streaming connections for the Connect protocol. + + Attributes: + writer (ServerResponseWriter): The writer used to send responses. + request (Request): The incoming request object. + _peer (Peer): The peer associated with this connection. + _spec (Spec): The specification object. + marshaler (ConnectStreamingMarshaler): The marshaler used to serialize messages. + unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler used to deserialize messages. + _request_headers (Headers): The headers from the request. + _response_headers (Headers): The headers for the response. + _response_trailers (Headers): The trailers for the response. 
+ + """ + + writer: ServerResponseWriter + request: Request + _peer: Peer + _spec: Spec + marshaler: ConnectStreamingMarshaler + unmarshaler: ConnectStreamingUnmarshaler + _request_headers: Headers + _response_headers: Headers + _response_trailers: Headers + + def __init__( + self, + writer: ServerResponseWriter, + request: Request, + peer: Peer, + spec: Spec, + marshaler: ConnectStreamingMarshaler, + unmarshaler: ConnectStreamingUnmarshaler, + request_headers: Headers, + response_headers: Headers, + response_trailers: Headers | None = None, + ) -> None: + """Initialize the protocol connection. + + Args: + writer (ServerResponseWriter): The writer for server responses. + request (Request): The request object. + peer (Peer): The peer information. + spec (Spec): The specification details. + marshaler (ConnectStreamingMarshaler): The marshaler for streaming. + unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler for streaming. + request_headers (Headers): The headers for the request. + response_headers (Headers): The headers for the response. + response_trailers (Headers, optional): The trailers for the response. Defaults to None. + + """ + self.writer = writer + self.request = request + self._peer = peer + self._spec = spec + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self._request_headers = request_headers + self._response_headers = response_headers + self._response_trailers = response_trailers if response_trailers is not None else Headers() + + def parse_timeout(self) -> float | None: + """Parse the timeout value.""" + try: + timeout = self.request.headers.get(CONNECT_HEADER_TIMEOUT) + if timeout is None: + return None + + timeout_ms = int(timeout) + except ValueError as e: + raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e + + return timeout_ms / 1000 + + @property + def spec(self) -> Spec: + """Return the specification object. + + Returns: + Spec: The specification object. 
+ + """ + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer associated with this instance. + + :return: The peer associated with this instance. + :rtype: Peer + """ + return self._peer + + async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously receives a message and yields unmarshaled objects. + + This method unmarshals the received message and yields each + unmarshaled object one by one as an asynchronous iterator. + + Args: + message (Any): The message to unmarshal. + + Returns: + AsyncIterator[Any]: An asynchronous iterator yielding unmarshaled objects. + + Yields: + Any: Each unmarshaled object from the message. + + """ + async for obj, _ in self.unmarshaler.unmarshal(message): + yield obj + + def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message and returns an asynchronous content stream. + + This method processes the incoming message through the receive_message method + and wraps the result in an AsyncContentStream with the appropriate stream type. + + Args: + message (Any): The message to be processed. + + Returns: + AsyncContentStream[Any]: An asynchronous stream of content based on the + processed message, configured with the specification's stream type. + + """ + return self._receive_messages(message) + + @property + def request_headers(self) -> Headers: + """Retrieve the headers from the request. + + Returns: + Mapping[str, str]: A dictionary-like object containing the request headers. + + """ + return self._request_headers + + async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + """Create an async iterator that marshals messages with error handling. 
+ + Args: + messages (AsyncIterable[Any]): Messages to marshal + + Returns: + AsyncIterator[bytes]: Marshaled bytes with end stream message + + Yields: + bytes: Each marshaled message followed by an end stream message + + """ + error: ConnectError | None = None + try: + async for message in self.marshaler.marshal(messages): + yield message + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + finally: + body = self.marshaler.marshal_end_stream(error, self.response_trailers) + yield body + + async def send(self, messages: AsyncIterable[Any]) -> None: + """Send a stream of messages asynchronously. + + This method marshals the provided messages and sends them using the writer. + If an error occurs during the marshaling process, it captures the error, + converts it to a JSON object, and sends it as the final message in the stream. + + Args: + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. + + Returns: + None + + Raises: + ConnectError: If an error occurs during the marshaling process. + + """ + await self.writer.write( + StreamingResponse( + content=self._send_messages(messages), + headers=self.response_headers, + status_code=200, + ) + ) + + @property + def response_headers(self) -> Headers: + """Retrieve the response headers. + + Returns: + Any: The response headers. + + """ + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Handle response trailers. + + This method is intended to be overridden in subclasses to provide + specific functionality for processing response trailers. + + Returns: + Any: The processed response trailer data. + + """ + return self._response_trailers + + async def send_error(self, error: ConnectError) -> None: + """Send an error response in the form of a JSON object. + + Args: + error (ConnectError): The error object to be sent. 
+
+        Returns:
+            None
+
+        """
+        body = self.marshaler.marshal_end_stream(error, self.response_trailers)
+
+        await self.writer.write(
+            StreamingResponse(content=aiterate([body]), headers=self.response_headers, status_code=200)
+        )
+
+
+def connect_check_protocol_version(request: Request, required: bool) -> ConnectError | None:
+    """Check the Connect protocol version carried by the request.
+
+    Args:
+        request (Request): The incoming HTTP request.
+        required (bool): Flag indicating whether the protocol version is required.
+
+    Returns:
+        ConnectError | None: An error describing a missing or unsupported
+        protocol version, or an unsupported HTTP method; None if the
+        request passes the check.
+
+    """
+    match HTTPMethod(request.method):
+        case HTTPMethod.GET:
+            version = request.query_params.get(CONNECT_UNARY_CONNECT_QUERY_PARAMETER)
+            if required and version is None:
+                return ConnectError(
+                    f'missing required parameter: set {CONNECT_UNARY_CONNECT_QUERY_PARAMETER} to "{CONNECT_UNARY_CONNECT_QUERY_VALUE}"'
+                )
+            elif version is not None and version != CONNECT_UNARY_CONNECT_QUERY_VALUE:
+                return ConnectError(
+                    f'{CONNECT_UNARY_CONNECT_QUERY_PARAMETER} must be "{CONNECT_UNARY_CONNECT_QUERY_VALUE}": get "{version}"',
+                )
+        case HTTPMethod.POST:
+            version = request.headers.get(CONNECT_HEADER_PROTOCOL_VERSION, None)
+            if required and version is None:
+                return ConnectError(
+                    f'missing required header: set {CONNECT_HEADER_PROTOCOL_VERSION} to "{CONNECT_PROTOCOL_VERSION}"',
+                    Code.INVALID_ARGUMENT,
+                )
+            elif version is not None and version != CONNECT_PROTOCOL_VERSION:
+                return ConnectError(
+                    f'{CONNECT_HEADER_PROTOCOL_VERSION} must be "{CONNECT_PROTOCOL_VERSION}": get "{version}"',
+                    Code.INVALID_ARGUMENT,
+                )
+        case _:
+            return ConnectError(f"unsupported method: {request.method}", Code.INVALID_ARGUMENT)
+
+    return None
+
+
+def error_to_json_bytes(error: ConnectError) -> bytes:
+    """Serialize a ConnectError object to a
JSON-encoded byte string. + + Args: + error (ConnectError): The ConnectError object to serialize. + + Returns: + bytes: The JSON-encoded byte string representation of the error. + + Raises: + ConnectError: If serialization fails, a ConnectError is raised with an + appropriate error message and code. + + """ + try: + json_obj = error_to_json(error) + json_str = json.dumps(json_obj) + + return json_str.encode() + except Exception as e: + raise ConnectError(f"failed to serialize Connect Error: {e}", Code.INTERNAL) from e diff --git a/src/connect/protocol_connect/connect_protocol.py b/src/connect/protocol_connect/connect_protocol.py new file mode 100644 index 0000000..9a7964e --- /dev/null +++ b/src/connect/protocol_connect/connect_protocol.py @@ -0,0 +1,59 @@ +from http import HTTPMethod + +from connect.connect import ( + StreamType, +) +from connect.idempotency_level import IdempotencyLevel +from connect.protocol import ( + Protocol, + ProtocolClient, + ProtocolClientParams, + ProtocolHandlerParams, +) +from connect.protocol_connect.connect_client import ConnectClient +from connect.protocol_connect.connect_handler import ConnectHandler +from connect.protocol_connect.constants import ( + CONNECT_STREAMING_CONTENT_TYPE_PREFIX, + CONNECT_UNARY_CONTENT_TYPE_PREFIX, +) + + +class ProtocolConnect(Protocol): + """ProtocolConnect is a class that implements the Protocol interface for handling connection protocols.""" + + def handler(self, params: ProtocolHandlerParams) -> ConnectHandler: + """Handle the creation of a ConnectHandler based on the provided ProtocolHandlerParams. + + Args: + params (ProtocolHandlerParams): The parameters required to create the ConnectHandler. + + Returns: + ConnectHandler: An instance of ConnectHandler configured with the appropriate methods and content types. 
+ + """ + methods = [HTTPMethod.POST] + + if params.spec.stream_type == StreamType.Unary and params.idempotency_level == IdempotencyLevel.NO_SIDE_EFFECTS: + methods.append(HTTPMethod.GET) + + content_types: list[str] = [] + for name in params.codecs.names(): + if params.spec.stream_type == StreamType.Unary: + content_types.append(CONNECT_UNARY_CONTENT_TYPE_PREFIX + name) + continue + + content_types.append(CONNECT_STREAMING_CONTENT_TYPE_PREFIX + name) + + return ConnectHandler(params, methods=methods, accept=content_types) + + def client(self, params: ProtocolClientParams) -> ProtocolClient: + """Create and returns a ConnectClient instance. + + Args: + params (ProtocolClientParams): The parameters required to initialize the client. + + Returns: + ProtocolClient: An instance of ConnectClient. + + """ + return ConnectClient(params) diff --git a/src/connect/protocol_connect/constants.py b/src/connect/protocol_connect/constants.py new file mode 100644 index 0000000..6f9e530 --- /dev/null +++ b/src/connect/protocol_connect/constants.py @@ -0,0 +1,25 @@ +from sys import version + +from connect.version import __version__ + +CONNECT_UNARY_HEADER_COMPRESSION = "Content-Encoding" +CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION = "Accept-Encoding" +CONNECT_UNARY_TRAILER_PREFIX = "Trailer-" +CONNECT_STREAMING_HEADER_COMPRESSION = "Connect-Content-Encoding" +CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION = "Connect-Accept-Encoding" +CONNECT_HEADER_TIMEOUT = "Connect-Timeout-Ms" +CONNECT_HEADER_PROTOCOL_VERSION = "Connect-Protocol-Version" +CONNECT_PROTOCOL_VERSION = "1" + +CONNECT_UNARY_CONTENT_TYPE_PREFIX = "application/" +CONNECT_UNARY_CONTENT_TYPE_JSON = "application/json" +CONNECT_STREAMING_CONTENT_TYPE_PREFIX = "application/connect+" + +CONNECT_UNARY_ENCODING_QUERY_PARAMETER = "encoding" +CONNECT_UNARY_MESSAGE_QUERY_PARAMETER = "message" +CONNECT_UNARY_BASE64_QUERY_PARAMETER = "base64" +CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER = "compression" +CONNECT_UNARY_CONNECT_QUERY_PARAMETER 
= "connect" +CONNECT_UNARY_CONNECT_QUERY_VALUE = "v" + CONNECT_PROTOCOL_VERSION + +DEFAULT_CONNECT_USER_AGENT = f"connect-python/{__version__} (Python/{version})" diff --git a/src/connect/protocol_connect/content_type.py b/src/connect/protocol_connect/content_type.py new file mode 100644 index 0000000..98ad9b1 --- /dev/null +++ b/src/connect/protocol_connect/content_type.py @@ -0,0 +1,103 @@ +from http import HTTPStatus + +from connect.code import Code +from connect.codec import CodecNameType +from connect.connect import ( + StreamType, +) +from connect.error import ConnectError +from connect.protocol import ( + code_from_http_status, +) +from connect.protocol_connect.constants import ( + CONNECT_STREAMING_CONTENT_TYPE_PREFIX, + CONNECT_UNARY_CONTENT_TYPE_PREFIX, +) + + +def connect_codec_from_content_type(stream_type: StreamType, content_type: str) -> str: + """Extract the codec from the content type based on the stream type. + + Args: + stream_type (StreamType): The type of stream (Unary or Streaming). + content_type (str): The content type string from which to extract the codec. + + Returns: + str: The extracted codec from the content type. + + """ + if stream_type == StreamType.Unary: + return content_type[len(CONNECT_UNARY_CONTENT_TYPE_PREFIX) :] + + return content_type[len(CONNECT_STREAMING_CONTENT_TYPE_PREFIX) :] + + +def connect_content_type_from_codec_name(stream_type: StreamType, codec_name: str) -> str: + """Generate the content type string for a given stream type and codec name. + + Args: + stream_type (StreamType): The type of the stream (e.g., Unary or Streaming). + codec_name (str): The name of the codec. + + Returns: + str: The content type string constructed from the stream type and codec name. 
+ + """ + if stream_type == StreamType.Unary: + return CONNECT_UNARY_CONTENT_TYPE_PREFIX + codec_name + + return CONNECT_STREAMING_CONTENT_TYPE_PREFIX + codec_name + + +def connect_validate_unary_response_content_type( + request_codec_name: str, + status_code: int, + response_content_type: str, +) -> ConnectError | None: + """Validate the content type of a unary response based on the HTTP status code and method. + + Args: + request_codec_name (str): The name of the codec used for the request. + http_method (HTTPMethod): The HTTP method used for the request. + status_code (int): The HTTP status code of the response. + response_content_type (str): The content type of the response. + + Raises: + ConnectError: If the status code is not OK and the response content type is not valid. + + """ + if status_code != HTTPStatus.OK: + # Error response must be JSON-encoded. + if ( + response_content_type == CONNECT_UNARY_CONTENT_TYPE_PREFIX + CodecNameType.JSON + or response_content_type == CONNECT_UNARY_CONTENT_TYPE_PREFIX + CodecNameType.JSON_CHARSET_UTF8 + ): + return ConnectError( + f"HTTP {status_code}", + code_from_http_status(status_code), + ) + + raise ConnectError( + f"HTTP {status_code}", + code_from_http_status(status_code), + ) + + if not response_content_type.startswith(CONNECT_UNARY_CONTENT_TYPE_PREFIX): + raise ConnectError( + f"invalid content-type: {response_content_type}; expecting {CONNECT_UNARY_CONTENT_TYPE_PREFIX}", + Code.UNKNOWN, + ) + + response_codec_name = connect_codec_from_content_type(StreamType.Unary, response_content_type) + if response_codec_name == request_codec_name: + return None + + if (response_codec_name == CodecNameType.JSON and request_codec_name == CodecNameType.JSON_CHARSET_UTF8) or ( + response_codec_name == CodecNameType.JSON_CHARSET_UTF8 and request_codec_name == CodecNameType.JSON + ): + return None + + raise ConnectError( + f"invalid content-type: {response_content_type}; expecting 
{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{request_codec_name}", + Code.INTERNAL, + ) diff --git a/src/connect/protocol_connect/end_stream.py b/src/connect/protocol_connect/end_stream.py new file mode 100644 index 0000000..96caf32 --- /dev/null +++ b/src/connect/protocol_connect/end_stream.py @@ -0,0 +1,75 @@ +import json +from typing import Any + +from connect.code import Code +from connect.error import ConnectError +from connect.headers import Headers +from connect.protocol_connect.error_json import error_from_json, error_to_json + + +def end_stream_to_json(error: ConnectError | None, trailers: Headers) -> dict[str, Any]: + """Convert the end of a stream to a JSON-serializable dictionary. + + Args: + error (ConnectError | None): An optional error object that may contain metadata. + trailers (Headers): Headers object containing metadata. + + Returns: + dict[str, Any]: A dictionary containing the error and metadata information in JSON-serializable format. + + """ + json_obj = {} + + metadata = Headers(trailers.copy()) + if error: + json_obj["error"] = error_to_json(error) + metadata.update(error.metadata.copy()) + + if len(metadata) > 0: + json_obj["metadata"] = {k: v.split(", ") for k, v in metadata.items()} + + return json_obj + + +def end_stream_from_bytes(data: bytes) -> tuple[ConnectError | None, Headers]: + """Parse a byte stream to extract metadata and error information. + + Args: + data (bytes): The byte stream to be parsed. + + Returns: + tuple[ConnectError | None, Headers]: A tuple containing an optional ConnectError + and a Headers object with the parsed metadata. + + Raises: + ConnectError: If the byte stream is invalid or the metadata format is incorrect. 
+ + """ + parse_error = ConnectError("invalid end stream", Code.UNKNOWN) + try: + obj = json.loads(data) + except Exception as e: + raise ConnectError( + "invalid end stream", + Code.UNKNOWN, + ) from e + + metadata = Headers() + if "metadata" in obj: + if not isinstance(obj["metadata"], dict) or not all( + isinstance(k, str) and isinstance(v, list) for k, v in obj["metadata"].items() + ): + raise ConnectError( + "invalid end stream", + Code.UNKNOWN, + ) + + for key, values in obj["metadata"].items(): + value = ", ".join(values) + metadata[key] = value + + if "error" in obj and obj["error"] is not None: + error = error_from_json(obj["error"], parse_error) + return error, metadata + else: + return None, metadata diff --git a/src/connect/protocol_connect/error_code.py b/src/connect/protocol_connect/error_code.py new file mode 100644 index 0000000..1291e60 --- /dev/null +++ b/src/connect/protocol_connect/error_code.py @@ -0,0 +1,67 @@ +from connect.code import Code + + +def connect_code_to_http(code: Code) -> int: + """Convert a given `Code` enumeration to its corresponding HTTP status code. + + Args: + code (Code): The `Code` enumeration value to be converted. + + Returns: + int: The corresponding HTTP status code. 
+ + The mapping is as follows: + - Code.CANCELED -> 499 + - Code.UNKNOWN -> 500 + - Code.INVALID_ARGUMENT -> 400 + - Code.DEADLINE_EXCEEDED -> 504 + - Code.NOT_FOUND -> 404 + - Code.ALREADY_EXISTS -> 409 + - Code.PERMISSION_DENIED -> 403 + - Code.RESOURCE_EXHAUSTED -> 429 + - Code.FAILED_PRECONDITION -> 400 + - Code.ABORTED -> 409 + - Code.OUT_OF_RANGE -> 400 + - Code.UNIMPLEMENTED -> 501 + - Code.INTERNAL -> 500 + - Code.UNAVAILABLE -> 503 + - Code.DATA_LOSS -> 500 + - Code.UNAUTHENTICATED -> 401 + - Any other code -> 500 + + """ + match code: + case Code.CANCELED: + return 499 + case Code.UNKNOWN: + return 500 + case Code.INVALID_ARGUMENT: + return 400 + case Code.DEADLINE_EXCEEDED: + return 504 + case Code.NOT_FOUND: + return 404 + case Code.ALREADY_EXISTS: + return 409 + case Code.PERMISSION_DENIED: + return 403 + case Code.RESOURCE_EXHAUSTED: + return 429 + case Code.FAILED_PRECONDITION: + return 400 + case Code.ABORTED: + return 409 + case Code.OUT_OF_RANGE: + return 400 + case Code.UNIMPLEMENTED: + return 501 + case Code.INTERNAL: + return 500 + case Code.UNAVAILABLE: + return 503 + case Code.DATA_LOSS: + return 500 + case Code.UNAUTHENTICATED: + return 401 + case _: + return 500 diff --git a/src/connect/protocol_connect/error_json.py b/src/connect/protocol_connect/error_json.py new file mode 100644 index 0000000..6830894 --- /dev/null +++ b/src/connect/protocol_connect/error_json.py @@ -0,0 +1,143 @@ +import base64 +import contextlib +import json +from typing import Any + +import google.protobuf.any_pb2 as any_pb2 +from google.protobuf import json_format + +from connect.code import Code +from connect.error import DEFAULT_ANY_RESOLVER_PREFIX, ConnectError, ErrorDetail + +_string_to_code: dict[str, Code] | None = None + + +def code_to_string(value: Code) -> str: + """Convert a Code object to its string representation. + + If the Code object has a 'name' attribute and it is not None, the method returns + the lowercase version of the 'name'. 
Otherwise, it returns the string representation + of the 'value' attribute. + + Args: + value (Code): The Code object to be converted to a string. + + Returns: + str: The string representation of the Code object. + + """ + if not hasattr(value, "name") or value.name is None: + return str(value.value) + + return value.name.lower() + + +def code_from_string(value: str) -> Code | None: + """Convert a string representation of a code to its corresponding Code enum value. + + This function uses a global dictionary to cache the mapping from string to Code enum values. + If the cache is not initialized, it populates the cache by iterating over all Code enum values + and mapping their string representations to the corresponding Code enum. + + Args: + value (str): The string representation of the code. + + Returns: + Code | None: The corresponding Code enum value if found, otherwise None. + + """ + global _string_to_code + + if _string_to_code is None: + _string_to_code = {} + for code in Code: + _string_to_code[code_to_string(code)] = code + + return _string_to_code.get(value) + + +def error_from_json(obj: dict[str, Any], fallback: ConnectError) -> ConnectError: + """Convert a JSON-serializable dictionary to a ConnectError object. + + Args: + obj (dict[str, Any]): The dictionary representing the error in JSON format. + fallback (ConnectError): A fallback ConnectError object to use in case of missing or invalid fields. + + Returns: + ConnectError: The ConnectError object converted from the dictionary. + + Raises: + ConnectError: If the dictionary is missing required fields or contains invalid values, + a ConnectError is raised with an appropriate error message and code. 
+ + """ + code = fallback.code + if "code" in obj: + code = code_from_string(obj["code"]) or code + + message = obj.get("message", "") + details = obj.get("details", []) + + error = ConnectError(message, code, wire_error=True) + + for detail in details: + type_name = detail.get("type", None) + value = detail.get("value", None) + + if type_name is None: + raise fallback + if value is None: + raise fallback + + type_name = type_name if "/" in type_name else DEFAULT_ANY_RESOLVER_PREFIX + type_name + try: + decoded = base64.b64decode(value.encode() + b"=" * (4 - len(value) % 4)) + except Exception as e: + raise fallback from e + + error.details.append( + ErrorDetail(pb_any=any_pb2.Any(type_url=type_name, value=decoded), wire_json=json.dumps(detail)) + ) + + return error + + +def error_to_json(error: ConnectError) -> dict[str, Any]: + """Convert a ConnectError object to a JSON-serializable dictionary. + + Args: + error (ConnectError): The error object to convert. + + Returns: + dict[str, Any]: A dictionary representing the error in JSON format. + - "code" (str): The error code as a string. + - "message" (str, optional): The raw error message, if available. + - "details" (list[dict[str, Any]], optional): A list of dictionaries containing error details, if available. + Each detail dictionary contains: + - "type" (str): The type name of the detail. + - "value" (str): The base64-encoded value of the detail. + - "debug" (str, optional): The JSON-encoded debug information, if available. 
+ + """ + obj: dict[str, Any] = {"code": error.code.string()} + + if len(error.raw_message) > 0: + obj["message"] = error.raw_message + + if len(error.details) > 0: + wires = [] + for detail in error.details: + wire: dict[str, Any] = { + "type": detail.pb_any.TypeName(), + "value": base64.b64encode(detail.pb_any.value).decode().rstrip("="), + } + + with contextlib.suppress(Exception): + meg = detail.get_inner() + wire["debug"] = json_format.MessageToDict(meg) + + wires.append(wire) + + obj["details"] = wires + + return obj diff --git a/src/connect/protocol_connect/marshaler.py b/src/connect/protocol_connect/marshaler.py new file mode 100644 index 0000000..342e4ae --- /dev/null +++ b/src/connect/protocol_connect/marshaler.py @@ -0,0 +1,351 @@ +import base64 +import contextlib +import json +from typing import Any + +from yarl import URL + +from connect.code import Code +from connect.codec import Codec, StableCodec +from connect.compression import Compression +from connect.envelope import EnvelopeFlags, EnvelopeWriter +from connect.error import ConnectError +from connect.headers import Headers +from connect.protocol import ( + HEADER_CONTENT_ENCODING, + HEADER_CONTENT_LENGTH, + HEADER_CONTENT_TYPE, +) +from connect.protocol_connect.constants import ( + CONNECT_HEADER_PROTOCOL_VERSION, + CONNECT_UNARY_BASE64_QUERY_PARAMETER, + CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, + CONNECT_UNARY_CONNECT_QUERY_PARAMETER, + CONNECT_UNARY_CONNECT_QUERY_VALUE, + CONNECT_UNARY_ENCODING_QUERY_PARAMETER, + CONNECT_UNARY_HEADER_COMPRESSION, + CONNECT_UNARY_MESSAGE_QUERY_PARAMETER, +) +from connect.protocol_connect.end_stream import end_stream_to_json + + +class ConnectUnaryMarshaler: + """ConnectUnaryMarshaler is responsible for serializing and optionally compressing messages. + + Attributes: + codec (Codec): The codec used for serializing messages. + compression (Compression | None): The compression method used for compressing messages, if any. 
+ compress_min_bytes (int): The minimum size in bytes for a message to be compressed. + send_max_bytes (int): The maximum allowed size in bytes for a message to be sent. + headers (Headers | Headers): The headers to be included in the message. + + """ + + codec: Codec | None + compression: Compression | None + compress_min_bytes: int + send_max_bytes: int + headers: Headers + + def __init__( + self, + codec: Codec | None, + compression: Compression | None, + compress_min_bytes: int, + send_max_bytes: int, + headers: Headers, + ) -> None: + """Initialize the protocol connection. + + Args: + codec (Codec): The codec to be used for encoding/decoding. + compression (Compression | None): The compression method to be used, or None if no compression. + compress_min_bytes (int): The minimum number of bytes before compression is applied. + send_max_bytes (int): The maximum number of bytes to send in a single message. + headers (Headers): The headers to be included in the connection. + + Returns: + None + + """ + self.codec = codec + self.compression = compression + self.compress_min_bytes = compress_min_bytes + self.send_max_bytes = send_max_bytes + self.headers = headers + + def marshal(self, message: Any) -> bytes: + """Marshals a message into bytes, optionally compressing it if it exceeds a certain size. + + Args: + message (Any): The message to be marshaled. + + Returns: + bytes: The marshaled (and possibly compressed) message. + + Raises: + ConnectError: If there is an error during marshaling or if the message size exceeds the allowed limit. 
+ + """ + if self.codec is None: + raise ConnectError("codec is not set", Code.INTERNAL) + + try: + data = self.codec.marshal(message) + except Exception as e: + raise ConnectError(f"marshal message: {str(e)}", Code.INTERNAL) from e + + if len(data) < self.compress_min_bytes or self.compression is None: + if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: + raise ConnectError( + f"message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}", Code.RESOURCE_EXHAUSTED + ) + + return data + + data = self.compression.compress(data) + + if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: + raise ConnectError( + f"compressed message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) + + self.headers[CONNECT_UNARY_HEADER_COMPRESSION] = self.compression.name + + return data + + +class ConnectUnaryRequestMarshaler(ConnectUnaryMarshaler): + """ConnectUnaryRequestMarshaler is responsible for marshaling unary request messages for the Connect protocol, with support for GET requests and stable codecs. + + This class extends ConnectUnaryMarshaler to provide additional functionality for handling GET requests, + including marshaling messages using a stable codec, enforcing message size limits, and optionally compressing + messages when necessary. It also manages the construction of GET URLs with appropriate query parameters and + headers for the Connect protocol. + + Attributes: + enable_get (bool): Flag indicating whether GET requests are enabled. + stable_codec (StableCodec | None): The codec used for stable marshaling, if available. + url (URL | None): The URL to use for the request. 
+ + """ + + enable_get: bool + stable_codec: StableCodec | None + url: URL | None + + def __init__( + self, + codec: Codec | None, + compression: Compression | None, + compress_min_bytes: int, + send_max_bytes: int, + headers: Headers, + enable_get: bool = False, + stable_codec: StableCodec | None = None, + url: URL | None = None, + ) -> None: + """Initialize the protocol connection with the specified configuration. + + Args: + codec (Codec | None): The codec to use for encoding/decoding messages, or None. + compression (Compression | None): The compression algorithm to use, or None. + compress_min_bytes (int): Minimum number of bytes before compression is applied. + send_max_bytes (int): Maximum number of bytes allowed per send operation. + headers (Headers): Headers to include in each request. + enable_get (bool, optional): Whether to enable GET requests. Defaults to False. + stable_codec (StableCodec | None, optional): An optional stable codec for message encoding/decoding. Defaults to None. + url (URL | None, optional): The URL endpoint for the connection. Defaults to None. + + Returns: + None + + """ + super().__init__(codec, compression, compress_min_bytes, send_max_bytes, headers) + self.enable_get = enable_get + self.stable_codec = stable_codec + self.url = url + + def marshal(self, message: Any) -> bytes: + """Marshal a message into bytes. + + If `enable_get` is True and `stable_codec` is None, raises a `ConnectError` + indicating that the codec does not support stable marshal and cannot use get. + Otherwise, if `enable_get` is True and `stable_codec` is not None, marshals + the message using the `marshal_with_get` method. + + If `enable_get` is False, marshals the message using the `. + + Args: + message (Any): The message to be marshaled. + + Returns: + bytes: The marshaled message in bytes. + + Raises: + ConnectError: If `enable_get` is True and `stable_codec` is None. 
+ + """ + if self.enable_get: + if self.codec is None: + raise ConnectError("codec is not set", Code.INTERNAL) + + if self.stable_codec is None: + raise ConnectError( + f"codec {self.codec.name} doesn't support stable marshal; can't use get", + Code.INTERNAL, + ) + else: + return self.marshal_with_get(message) + + return super().marshal(message) + + def marshal_with_get(self, message: Any) -> bytes: + """Marshals the given message and sends it using a GET request. + + This method first marshals the message using the stable codec. If the marshaled + data exceeds the maximum allowed size (`send_max_bytes`) and compression is not + enabled, it raises a `ConnectError`. If the data size is within the limit, it + builds the GET URL and sends the data. + + If the data size exceeds the limit and compression is enabled, it compresses + the data and checks the size again. If the compressed data still exceeds the + limit, it raises a `ConnectError`. Otherwise, it builds the GET URL with the + compressed data and sends it. + + Args: + message (Any): The message to be marshaled and sent. + + Returns: + bytes: The marshaled (and possibly compressed) data. + + Raises: + ConnectError: If the data size exceeds the maximum allowed size and compression + is not enabled, or if the compressed data size still exceeds the + limit. 
+ + """ + if self.stable_codec is None: + raise ConnectError("stable_codec is not set", Code.INTERNAL) + + data = self.stable_codec.marshal_stable(message) + + is_too_big = self.send_max_bytes > 0 and len(data) > self.send_max_bytes + if is_too_big and not self.compression: + raise ConnectError( + f"message size {len(data)} exceeds sendMaxBytes {self.send_max_bytes}: enabling request compression may help", + Code.RESOURCE_EXHAUSTED, + ) + + if not is_too_big: + url = self._build_get_url(data, False) + + self._write_with_get(url) + return data + + if self.compression: + data = self.compression.compress(data) + + if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: + raise ConnectError( + f"compressed message size {len(data)} exceeds send_max_bytes {self.send_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) + + url = self._build_get_url(data, True) + self._write_with_get(url) + + return data + + def _build_get_url(self, data: bytes, compressed: bool) -> URL: + if self.url is None or self.stable_codec is None: + raise ConnectError("url or stable_codec is not set", Code.INTERNAL) + + if self.codec is None: + raise ConnectError("codec is not set", Code.INTERNAL) + + url = self.url + url = url.update_query({ + CONNECT_UNARY_CONNECT_QUERY_PARAMETER: CONNECT_UNARY_CONNECT_QUERY_VALUE, + CONNECT_UNARY_ENCODING_QUERY_PARAMETER: self.codec.name, + }) + if self.stable_codec.is_binary() or compressed: + url = url.update_query({ + CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8"), + CONNECT_UNARY_BASE64_QUERY_PARAMETER: "1", + }) + else: + url = url.update_query({ + CONNECT_UNARY_MESSAGE_QUERY_PARAMETER: data.decode("utf-8"), + }) + + if compressed: + if not self.compression: + raise ConnectError( + "compression must be set for compressed message", + Code.INTERNAL, + ) + + url = url.update_query({CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER: self.compression.name}) + + return url + + def _write_with_get(self, url: URL) -> None: 
+ with contextlib.suppress(Exception): + del self.headers[CONNECT_HEADER_PROTOCOL_VERSION] + del self.headers[HEADER_CONTENT_TYPE] + del self.headers[HEADER_CONTENT_ENCODING] + del self.headers[HEADER_CONTENT_LENGTH] + + self.url = url + + +class ConnectStreamingMarshaler(EnvelopeWriter): + """A class responsible for marshaling messages with optional compression. + + Attributes: + codec (Codec): The codec used for marshaling messages. + compression (Compression | None): The compression method used for compressing messages, if any. + + """ + + codec: Codec | None + compress_min_bytes: int + send_max_bytes: int + compression: Compression | None + + def __init__( + self, codec: Codec | None, compression: Compression | None, compress_min_bytes: int, send_max_bytes: int + ) -> None: + """Initialize the ProtocolConnect instance. + + Args: + codec (Codec): The codec to be used for encoding and decoding. + compression (Compression | None): The compression method to be used, or None if no compression is to be applied. + compress_min_bytes (int): The minimum number of bytes before compression is applied. + send_max_bytes (int): The maximum number of bytes that can be sent in a single message. + + """ + self.codec = codec + self.compress_min_bytes = compress_min_bytes + self.send_max_bytes = send_max_bytes + self.compression = compression + + def marshal_end_stream(self, error: ConnectError | None, response_trailers: Headers) -> bytes: + """Serialize the end-of-stream message with optional error and response trailers into a bytes envelope. + + Args: + error (ConnectError | None): An optional error object to include in the end-of-stream message. + response_trailers (Headers): Headers to include as response trailers. + + Returns: + bytes: The serialized envelope containing the end-of-stream message. 
+ + """ + json_obj = end_stream_to_json(error, response_trailers) + json_str = json.dumps(json_obj) + + env = self.write_envelope(json_str.encode(), EnvelopeFlags.end_stream) + + return env.encode() diff --git a/src/connect/protocol_connect/unmarshaler.py b/src/connect/protocol_connect/unmarshaler.py new file mode 100644 index 0000000..5fd5615 --- /dev/null +++ b/src/connect/protocol_connect/unmarshaler.py @@ -0,0 +1,214 @@ +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Callable, +) +from typing import Any + +from connect.code import Code +from connect.codec import Codec +from connect.compression import Compression +from connect.envelope import EnvelopeReader +from connect.error import ConnectError +from connect.headers import Headers +from connect.protocol_connect.end_stream import end_stream_from_bytes +from connect.utils import get_acallable_attribute + + +class ConnectUnaryUnmarshaler: + """A class to handle the unmarshaling of data using a specified codec. + + Attributes: + codec (Codec): The codec used for unmarshaling the data. + body (bytes): The raw data to be unmarshaled. + read_max_bytes (int): The maximum number of bytes to read. + compression (Compression | None): The compression method to use, if any. + + """ + + codec: Codec | None + read_max_bytes: int + compression: Compression | None + stream: AsyncIterable[bytes] | None + + def __init__( + self, + codec: Codec | None, + read_max_bytes: int, + compression: Compression | None = None, + stream: AsyncIterable[bytes] | None = None, + ) -> None: + """Initialize the ProtocolConnect object. + + Args: + stream (AsyncIterable[bytes] | None): The stream of bytes to be unmarshaled. + codec (Codec): The codec used for encoding/decoding the message. + read_max_bytes (int): The maximum number of bytes to read. + compression (Compression | None): The compression method to use, if any. 
+ + """ + self.codec = codec + self.read_max_bytes = read_max_bytes + self.compression = compression + self.stream = stream + + async def unmarshal(self, message: Any) -> Any: + """Asynchronously unmarshals a given message using the provided unmarshal function and codec. + + Args: + message (Any): The message to be unmarshaled. + + Returns: + Any: The result of the unmarshaling process. + + """ + if self.codec is None: + raise ConnectError("codec is not set", Code.INTERNAL) + + return await self.unmarshal_func(message, self.codec.unmarshal) + + async def unmarshal_func(self, message: Any, func: Callable[[bytes, Any], Any]) -> Any: + """Asynchronously unmarshals a message using the provided function. + + This function reads data from the stream in chunks, checks if the total + bytes read exceed the maximum allowed bytes, and optionally decompresses + the data. It then uses the provided function to unmarshal the data into + the desired format. + + Args: + message (Any): The message to be unmarshaled. + func (Callable[[bytes, Any], Any]): A function that takes the raw bytes + and the message, and returns the unmarshaled object. + + Returns: + Any: The unmarshaled object. + + Raises: + ConnectError: If the stream is not set, if the message size exceeds the + maximum allowed bytes, or if there is an error during unmarshaling. 
+ + """ + if self.stream is None: + raise ConnectError("stream is not set", Code.INTERNAL) + + chunks: list[bytes] = [] + bytes_read = 0 + try: + async for chunk in self.stream: + chunk_size = len(chunk) + bytes_read += chunk_size + if self.read_max_bytes > 0 and bytes_read > self.read_max_bytes: + raise ConnectError( + f"message size {bytes_read} is larger than configured max {self.read_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) + + chunks.append(chunk) + + data = b"".join(chunks) + + if len(data) > 0 and self.compression: + data = self.compression.decompress(data, self.read_max_bytes) + + try: + obj = func(data, message) + except Exception as e: + raise ConnectError( + f"unmarshal message: {str(e)}", + Code.INVALID_ARGUMENT, + ) from e + finally: + await self.aclose() + + return obj + + async def aclose(self) -> None: + """Asynchronously close the stream if it is set. + + This method is intended to be called when the stream is no longer needed + to release any associated resources. + + """ + aclose = get_acallable_attribute(self.stream, "aclose") + if aclose: + await aclose() + + +class ConnectStreamingUnmarshaler(EnvelopeReader): + """A class to handle the unmarshaling of streaming data. + + Attributes: + codec (Codec): The codec used for unmarshaling data. + compression (Compression | None): The compression method used, if any. + stream (AsyncIterable[bytes] | None): The asynchronous byte stream to read from. + buffer (bytes): The buffer to store incoming data chunks. + + """ + + _end_stream_error: ConnectError | None + _trailers: Headers + + def __init__( + self, + codec: Codec | None, + read_max_bytes: int, + stream: AsyncIterable[bytes] | None = None, + compression: Compression | None = None, + ) -> None: + """Initialize the protocol connection. + + Args: + codec (Codec): The codec to use for encoding and decoding data. + read_max_bytes (int): The maximum number of bytes to read from the stream. 
+ stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read from. Defaults to None. + compression (Compression | None, optional): The compression method to use. Defaults to None. + + """ + super().__init__(codec, read_max_bytes, stream, compression) + self._end_stream_error = None + self._trailers = Headers() + + async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: + """Asynchronously unmarshals messages from the stream. + + Args: + message (Any): The message type to unmarshal. + + Yields: + Any: The unmarshaled message object. + + Raises: + ConnectError: If the stream is not set, if there is an error in the + unmarshaling process, or if there is a protocol error. + + """ + async for obj, end in super().unmarshal(message): + if self.last: + error, trailers = end_stream_from_bytes(self.last.data) + self._end_stream_error = error + self._trailers = trailers + + yield obj, end + + @property + def trailers(self) -> Headers: + """Return the trailers headers. + + Trailers are additional headers sent after the body of the message. + + Returns: + Headers: The trailers headers. + + """ + return self._trailers + + @property + def end_stream_error(self) -> ConnectError | None: + """Return the error that occurred at the end of the stream, if any. + + Returns: + ConnectError | None: The error that occurred at the end of the stream, + or None if no error occurred. 
+ + """ + return self._end_stream_error From 420af941662d80287032d321f8c87a3167f5a59d Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Tue, 13 May 2025 18:33:11 +0900 Subject: [PATCH 2/3] protocol_grpc: refactor protocol_connect file --- src/connect/client.py | 2 +- src/connect/handler.py | 2 +- .../protocol_connect/connect_client.py | 2 +- .../protocol_connect/connect_handler.py | 2 + .../protocol_connect/connect_protocol.py | 2 + src/connect/protocol_connect/constants.py | 2 + src/connect/protocol_connect/content_type.py | 2 + src/connect/protocol_connect/end_stream.py | 2 + src/connect/protocol_connect/error_code.py | 2 + src/connect/protocol_connect/error_json.py | 2 + src/connect/protocol_connect/marshaler.py | 2 + src/connect/protocol_connect/unmarshaler.py | 2 + src/connect/protocol_grpc.py | 1356 ----------------- src/connect/protocol_grpc/__init__.py | 0 src/connect/protocol_grpc/constants.py | 38 + src/connect/protocol_grpc/content_type.py | 89 ++ src/connect/protocol_grpc/error_trailer.py | 162 ++ src/connect/protocol_grpc/grpc_client.py | 478 ++++++ src/connect/protocol_grpc/grpc_handler.py | 437 ++++++ src/connect/protocol_grpc/grpc_protocol.py | 91 ++ src/connect/protocol_grpc/marshaler.py | 62 + src/connect/protocol_grpc/unmarshaler.py | 105 ++ 22 files changed, 1483 insertions(+), 1359 deletions(-) delete mode 100644 src/connect/protocol_grpc.py create mode 100644 src/connect/protocol_grpc/__init__.py create mode 100644 src/connect/protocol_grpc/constants.py create mode 100644 src/connect/protocol_grpc/content_type.py create mode 100644 src/connect/protocol_grpc/error_trailer.py create mode 100644 src/connect/protocol_grpc/grpc_client.py create mode 100644 src/connect/protocol_grpc/grpc_handler.py create mode 100644 src/connect/protocol_grpc/grpc_protocol.py create mode 100644 src/connect/protocol_grpc/marshaler.py create mode 100644 src/connect/protocol_grpc/unmarshaler.py diff --git a/src/connect/client.py b/src/connect/client.py index 
3fa27e4..b773e9c 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -29,7 +29,7 @@ from connect.options import ClientOptions from connect.protocol import Protocol, ProtocolClient, ProtocolClientParams from connect.protocol_connect.connect_protocol import ProtocolConnect -from connect.protocol_grpc import ProtocolGRPC +from connect.protocol_grpc.grpc_protocol import ProtocolGRPC from connect.session import AsyncClientSession from connect.utils import aiterate diff --git a/src/connect/handler.py b/src/connect/handler.py index 8494ab6..2d65d88 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -40,7 +40,7 @@ from connect.protocol_connect.connect_protocol import ( ProtocolConnect, ) -from connect.protocol_grpc import ProtocolGRPC +from connect.protocol_grpc.grpc_protocol import ProtocolGRPC from connect.request import Request from connect.response import Response from connect.utils import aiterate diff --git a/src/connect/protocol_connect/connect_client.py b/src/connect/protocol_connect/connect_client.py index 57c076a..ad31d7e 100644 --- a/src/connect/protocol_connect/connect_client.py +++ b/src/connect/protocol_connect/connect_client.py @@ -1,4 +1,4 @@ -"""Provides classes and functions for handling protocol connections.""" +"""Provides a ConnectClient class for handling connections using the Connect protocol.""" import asyncio import contextlib diff --git a/src/connect/protocol_connect/connect_handler.py b/src/connect/protocol_connect/connect_handler.py index 28bca96..be1bcb6 100644 --- a/src/connect/protocol_connect/connect_handler.py +++ b/src/connect/protocol_connect/connect_handler.py @@ -1,3 +1,5 @@ +"""Provides a ConnectHandler class for handling connection protocols.""" + import base64 import json from collections.abc import ( diff --git a/src/connect/protocol_connect/connect_protocol.py b/src/connect/protocol_connect/connect_protocol.py index 9a7964e..0aaefec 100644 --- a/src/connect/protocol_connect/connect_protocol.py +++ 
b/src/connect/protocol_connect/connect_protocol.py @@ -1,3 +1,5 @@ +"""Provides the ProtocolConnect class for handling connection protocols.""" + from http import HTTPMethod from connect.connect import ( diff --git a/src/connect/protocol_connect/constants.py b/src/connect/protocol_connect/constants.py index 6f9e530..bfe96c8 100644 --- a/src/connect/protocol_connect/constants.py +++ b/src/connect/protocol_connect/constants.py @@ -1,3 +1,5 @@ +"""Constants used in the Connect protocol implementation for Python.""" + from sys import version from connect.version import __version__ diff --git a/src/connect/protocol_connect/content_type.py b/src/connect/protocol_connect/content_type.py index 98ad9b1..c3141e0 100644 --- a/src/connect/protocol_connect/content_type.py +++ b/src/connect/protocol_connect/content_type.py @@ -1,3 +1,5 @@ +"""Utilities for handling Connect protocol content types.""" + from http import HTTPStatus from connect.code import Code diff --git a/src/connect/protocol_connect/end_stream.py b/src/connect/protocol_connect/end_stream.py index 96caf32..95b03cc 100644 --- a/src/connect/protocol_connect/end_stream.py +++ b/src/connect/protocol_connect/end_stream.py @@ -1,3 +1,5 @@ +"""Module for handling end-of-stream JSON serialization and deserialization for Connect protocol.""" + import json from typing import Any diff --git a/src/connect/protocol_connect/error_code.py b/src/connect/protocol_connect/error_code.py index 1291e60..349a9a1 100644 --- a/src/connect/protocol_connect/error_code.py +++ b/src/connect/protocol_connect/error_code.py @@ -1,3 +1,5 @@ +"""Module for mapping Connect error codes to HTTP status codes.""" + from connect.code import Code diff --git a/src/connect/protocol_connect/error_json.py b/src/connect/protocol_connect/error_json.py index 6830894..4d29e1c 100644 --- a/src/connect/protocol_connect/error_json.py +++ b/src/connect/protocol_connect/error_json.py @@ -1,3 +1,5 @@ +"""Module for serializing and deserializing ConnectError objects 
to and from JSON.""" + import base64 import contextlib import json diff --git a/src/connect/protocol_connect/marshaler.py b/src/connect/protocol_connect/marshaler.py index 342e4ae..af3c12d 100644 --- a/src/connect/protocol_connect/marshaler.py +++ b/src/connect/protocol_connect/marshaler.py @@ -1,3 +1,5 @@ +"""Provides marshaling utilities for the Connect protocol.""" + import base64 import contextlib import json diff --git a/src/connect/protocol_connect/unmarshaler.py b/src/connect/protocol_connect/unmarshaler.py index 5fd5615..64fbd5d 100644 --- a/src/connect/protocol_connect/unmarshaler.py +++ b/src/connect/protocol_connect/unmarshaler.py @@ -1,3 +1,5 @@ +"""Module providing classes for unmarshaling unary and streaming Connect protocol messages.""" + from collections.abc import ( AsyncIterable, AsyncIterator, diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py deleted file mode 100644 index 7c0251c..0000000 --- a/src/connect/protocol_grpc.py +++ /dev/null @@ -1,1356 +0,0 @@ -"""Provaides classes and functions for handling gRPC protocol.""" - -import asyncio -import base64 -import contextlib -import functools -import re -import sys -import urllib.parse -from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping -from copy import copy -from http import HTTPMethod -from typing import Any -from urllib.parse import unquote - -import httpcore -from google.protobuf.message import DecodeError -from google.rpc import status_pb2 -from yarl import URL - -from connect.byte_stream import HTTPCoreResponseAsyncByteStream -from connect.code import Code -from connect.codec import Codec, CodecNameType -from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name -from connect.connect import ( - Address, - Peer, - Spec, - StreamingClientConn, - StreamingHandlerConn, - StreamType, -) -from connect.envelope import EnvelopeFlags, EnvelopeReader, EnvelopeWriter -from connect.error import ConnectError, ErrorDetail 
-from connect.headers import Headers, include_request_headers -from connect.protocol import ( - HEADER_CONTENT_TYPE, - HEADER_USER_AGENT, - PROTOCOL_GRPC, - PROTOCOL_GRPC_WEB, - Protocol, - ProtocolClient, - ProtocolClientParams, - ProtocolHandler, - ProtocolHandlerParams, - code_from_http_status, - exclude_protocol_headers, - negotiate_compression, -) -from connect.request import Request -from connect.session import AsyncClientSession -from connect.streaming_response import StreamingResponse -from connect.utils import aiterate, map_httpcore_exceptions -from connect.version import __version__ -from connect.writer import ServerResponseWriter - -GRPC_HEADER_COMPRESSION = "Grpc-Encoding" -GRPC_HEADER_ACCEPT_COMPRESSION = "Grpc-Accept-Encoding" -GRPC_HEADER_TIMEOUT = "Grpc-Timeout" -GRPC_HEADER_STATUS = "Grpc-Status" -GRPC_HEADER_MESSAGE = "Grpc-Message" -GRPC_HEADER_DETAILS = "Grpc-Status-Details-Bin" - -GRPC_CONTENT_TYPE_DEFAULT = "application/grpc" -GRPC_WEB_CONTENT_TYPE_DEFAULT = "application/grpc-web" -GRPC_CONTENT_TYPE_PREFIX = GRPC_CONTENT_TYPE_DEFAULT + "+" -GRPC_WEB_CONTENT_TYPE_PREFIX = GRPC_WEB_CONTENT_TYPE_DEFAULT + "+" - -HEADER_X_USER_AGENT = "X-User-Agent" - -GRPC_ALLOWED_METHODS = [HTTPMethod.POST] - -DEFAULT_GRPC_USER_AGENT = f"connect-python/{__version__} (Python/{__version__})" - - -_RE = re.compile(r"^(\d{1,8})([HMSmun])$") -_UNIT_TO_SECONDS = { - "n": 1e-9, # nanosecond - "u": 1e-6, # microsecond - "m": 1e-3, # millisecond - "S": 1.0, - "M": 60.0, - "H": 3600.0, -} -_MAX_HOURS = sys.maxsize // (60 * 60 * 1_000_000_000) - - -class ProtocolGRPC(Protocol): - """ProtocolGRPC is a protocol implementation for handling gRPC and gRPC-Web requests. - - Attributes: - web (bool): Indicates whether to use gRPC-Web (True) or standard gRPC (False). - - """ - - def __init__(self, web: bool) -> None: - """Initialize the instance. - - Args: - web (bool): Indicates whether the instance is for web usage. 
- - """ - self.web = web - - def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler: - """Create and returns a GRPCHandler instance configured with appropriate content types based on the provided parameters. - - Args: - params (ProtocolHandlerParams): The parameters containing codec information and other handler configuration. - - Returns: - ProtocolHandler: An instance of GRPCHandler initialized with the correct content types for gRPC or gRPC-Web. - - Behavior: - - Determines the default and prefix content types based on whether gRPC-Web is enabled. - - Constructs a list of supported content types from the available codecs. - - Adds the bare content type if the PROTO codec is present. - - Returns a GRPCHandler with the computed content types. - - """ - bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX - if self.web: - bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX - - content_types: list[str] = [] - for name in params.codecs.names(): - content_types.append(prefix + name) - - if params.codecs.get(CodecNameType.PROTO): - content_types.append(bare) - - return GRPCHandler(params, self.web, content_types) - - def client(self, params: ProtocolClientParams) -> ProtocolClient: - """Create and return a GRPCClient instance. - - Args: - params (ProtocolClientParams): The parameters required to initialize the client. - - Returns: - ProtocolClient: An instance of GRPCClient. - - """ - peer = Peer( - address=Address(host=params.url.host or "", port=params.url.port or 80), - protocol=PROTOCOL_GRPC, - query={}, - ) - if self.web: - peer.protocol = PROTOCOL_GRPC_WEB - - return GRPCClient(params, peer, self.web) - - -class GRPCHandler(ProtocolHandler): - """GRPCHandler is a protocol handler for gRPC and gRPC-Web requests. 
- - This class implements the ProtocolHandler interface to handle gRPC protocol requests, - including negotiation of compression, codec selection, and connection management for - both standard gRPC and gRPC-Web. It supports content type negotiation, payload handling, - and manages the lifecycle of a gRPC connection, including streaming and non-streaming - requests. - - Attributes: - params (ProtocolHandlerParams): Configuration parameters for the handler, including codecs and compressions. - web (bool): Indicates if the handler is for gRPC-Web. - accept (list[str]): List of accepted content types. - - """ - - params: ProtocolHandlerParams - web: bool - accept: list[str] - - def __init__(self, params: ProtocolHandlerParams, web: bool, accept: list[str]) -> None: - """Initialize the ProtocolHandler with the given parameters. - - Args: - params (ProtocolHandlerParams): The parameters required for the protocol handler. - web (bool): Indicates whether the handler is for web usage. - accept (list[str]): A list of accepted content types. - - Returns: - None - - """ - self.params = params - self.web = web - self.accept = accept - - @property - def methods(self) -> list[HTTPMethod]: - """Returns a list of allowed HTTP methods for gRPC protocol. - - Returns: - list[HTTPMethod]: A list containing the HTTP methods permitted for gRPC communication. - - """ - return GRPC_ALLOWED_METHODS - - def content_types(self) -> list[str]: - """Return a list of accepted content types. - - Returns: - list[str]: A list of MIME types that are accepted. - - """ - return self.accept - - def can_handle_payload(self, _: Request, content_type: str) -> bool: - """Determine if the given content type is supported by this handler. - - Args: - _ (Request): The request object (unused). - content_type (str): The MIME type of the payload to check. - - Returns: - bool: True if the content type is accepted, False otherwise. 
- - """ - return content_type in self.accept - - async def conn( - self, - request: Request, - response_headers: Headers, - response_trailers: Headers, - writer: ServerResponseWriter, - ) -> StreamingHandlerConn | None: - """Handle a connection request. - - Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be sent in the response. - response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer used to send the response. - is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. - - Returns: - StreamingHandlerConn | None: The connection handler or None if not implemented. - - """ - content_encoding = request.headers.get(GRPC_HEADER_COMPRESSION) - accept_encoding = request.headers.get(GRPC_HEADER_ACCEPT_COMPRESSION) - - request_compression, response_compression, error = negotiate_compression( - self.params.compressions, content_encoding, accept_encoding - ) - - response_headers[HEADER_CONTENT_TYPE] = request.headers.get(HEADER_CONTENT_TYPE, "") - response_headers[GRPC_HEADER_ACCEPT_COMPRESSION] = f"{', '.join(c.name for c in self.params.compressions)}" - if response_compression and response_compression.name != COMPRESSION_IDENTITY: - response_headers[GRPC_HEADER_COMPRESSION] = response_compression.name - - codec_name = grpc_codec_from_content_type(self.web, request.headers.get(HEADER_CONTENT_TYPE, "")) - codec = self.params.codecs.get(codec_name) - protocol_name = PROTOCOL_GRPC if not self.web else PROTOCOL_GRPC + "-web" - - peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, - protocol=protocol_name, - query=request.query_params, - ) - - conn = GRPCHandlerConn( - web=self.web, - writer=writer, - spec=self.params.spec, - peer=peer, - marshaler=GRPCMarshaler( - codec, - response_compression, - self.params.compress_min_bytes, - self.params.send_max_bytes, - ), - 
unmarshaler=GRPCUnmarshaler( - self.web, - codec, - self.params.read_max_bytes, - request.stream(), - request_compression, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) - - if error: - await conn.send_error(error) - return None - - return conn - - -class GRPCClient(ProtocolClient): - """GRPCClient is a protocol client implementation for gRPC communication, supporting both standard and web environments. - - Attributes: - params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, session, and URL. - _peer (Peer): The peer instance associated with this client, representing the remote endpoint. - web (bool): Indicates whether the client is running in a web environment, affecting header and content-type handling. - - """ - - params: ProtocolClientParams - _peer: Peer - web: bool - - def __init__(self, params: ProtocolClientParams, peer: Peer, web: bool) -> None: - """Initialize the ProtocolClient with the given parameters. - - Args: - params (ProtocolClientParams): The parameters for the protocol client. - peer (Peer): The peer instance to be used. - web (bool): Indicates whether the client is running in a web environment. - - """ - self.params = params - self._peer = peer - self.web = web - - @property - def peer(self) -> Peer: - """Returns the associated Peer object. - - Returns: - Peer: The peer instance associated with this object. - - """ - return self._peer - - def write_request_headers(self, _: StreamType, headers: Headers) -> None: - """Set and modifies HTTP/2 or gRPC request headers based on the stream type, connection parameters, and environment. - - Args: - stream_type (StreamType): The type of stream for which headers are being written. - headers (Headers): The dictionary of headers to be modified or populated. 
- - Behavior: - - Ensures the 'User-Agent' header is set to the default gRPC user agent if not already present. - - If running in a web environment, also sets the 'X-User-Agent' header. - - Sets the 'Content-Type' header according to the codec name and environment. - - Sets the 'Accept-Encoding' header to indicate supported compression. - - If a specific compression is configured and is not the identity, sets the gRPC compression header. - - If multiple compressions are supported, sets the gRPC accept compression header with the supported values. - - For non-web environments, adds the 'Te: trailers' header required for gRPC. - - Note: - This method mutates the provided headers dictionary in place. - - """ - if headers.get(HEADER_USER_AGENT, None) is None: - headers[HEADER_USER_AGENT] = DEFAULT_GRPC_USER_AGENT - - if self.web and headers.get(HEADER_X_USER_AGENT, None) is None: - headers[HEADER_X_USER_AGENT] = DEFAULT_GRPC_USER_AGENT - - headers[HEADER_CONTENT_TYPE] = grpc_content_type_from_codec_name(self.web, self.params.codec.name) - - headers["Accept-Encoding"] = COMPRESSION_IDENTITY - if self.params.compression_name and self.params.compression_name != COMPRESSION_IDENTITY: - headers[GRPC_HEADER_COMPRESSION] = self.params.compression_name - - if self.params.compressions: - headers[GRPC_HEADER_ACCEPT_COMPRESSION] = ", ".join(c.name for c in self.params.compressions) - - if not self.web: - headers["Te"] = "trailers" - - def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Create and returns a GRPCClientConn instance configured with the provided specification and headers. - - Args: - spec (Spec): The specification object defining the protocol or service interface. - headers (Headers): The request headers to include in the connection. - - Returns: - StreamingClientConn: An initialized gRPC streaming client connection. - - Details: - - Configures the connection with parameters such as session, peer, URL, codec, and compression settings. 
- - Initializes GRPCMarshaler and GRPCUnmarshaler with appropriate codecs and limits. - - Compression is determined using the provided compression name and available compressions. - - """ - return GRPCClientConn( - web=self.web, - session=self.params.session, - spec=spec, - peer=self.peer, - url=self.params.url, - codec=self.params.codec, - compressions=self.params.compressions, - marshaler=GRPCMarshaler( - codec=self.params.codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), - ), - unmarshaler=GRPCUnmarshaler( - web=self.web, - codec=self.params.codec, - read_max_bytes=self.params.read_max_bytes, - ), - request_headers=headers, - ) - - -class GRPCMarshaler(EnvelopeWriter): - """GRPCMarshaler is responsible for marshaling messages into the gRPC wire format. - - Args: - codec (Codec | None): The codec used for encoding/decoding messages. - compression (Compression | None): The compression algorithm to use, if any. - compress_min_bytes (int): Minimum message size in bytes before compression is applied. - send_max_bytes (int): Maximum allowed size of a message to send. - - Methods: - marshal(messages: AsyncIterable[bytes]) -> AsyncIterator[bytes]: - Asynchronously marshals a stream of message bytes into the gRPC wire format. - Yields marshaled message bytes ready for transmission. - - """ - - def __init__( - self, - codec: Codec | None, - compression: Compression | None, - compress_min_bytes: int, - send_max_bytes: int, - ) -> None: - """Initialize the protocol with the specified configuration. - - Args: - codec (Codec | None): The codec to use for encoding/decoding messages, or None for default. - compression (Compression | None): The compression algorithm to use, or None for no compression. - compress_min_bytes (int): The minimum number of bytes before compression is applied. 
- send_max_bytes (int): The maximum number of bytes allowed to send in a single message. - - Returns: - None - - """ - super().__init__(codec, compression, compress_min_bytes, send_max_bytes) - - async def marshal_web_trailers(self, trailers: Headers) -> bytes: - """Serialize HTTP trailer headers into a gRPC-Web trailer envelope. - - Args: - trailers (Headers): A dictionary-like object containing HTTP trailer headers. - - Returns: - bytes: The serialized gRPC-Web trailer envelope containing the trailer headers. - - """ - lines = [] - for key, value in trailers.items(): - lines.append(f"{key}: {value}\r\n") - - env = self.write_envelope("".join(lines).encode(), EnvelopeFlags.trailer) - - return env.encode() - - -class GRPCUnmarshaler(EnvelopeReader): - """GRPCUnmarshaler is a specialized EnvelopeReader for handling gRPC message unmarshaling. - - Args: - codec (Codec | None): The codec used for decoding messages. - read_max_bytes (int): The maximum number of bytes to read from the stream. - stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read messages from. - compression (Compression | None, optional): Compression algorithm to use for decompressing messages. - - Methods: - async unmarshal(message: Any) -> AsyncIterator[Any]: - Asynchronously unmarshals the given message, yielding each decoded object. - Iterates over the results of the internal _unmarshal method, yielding only the object part of each tuple. - - """ - - web: bool - _web_trailers: Headers | None - - def __init__( - self, - web: bool, - codec: Codec | None, - read_max_bytes: int, - stream: AsyncIterable[bytes] | None = None, - compression: Compression | None = None, - ) -> None: - """Initialize the protocol gRPC handler. - - Args: - web (bool): Indicates if the connection is for a web environment. - codec (Codec | None): The codec to use for encoding/decoding messages. Can be None. - read_max_bytes (int): The maximum number of bytes to read from the stream. 
- stream (AsyncIterable[bytes] | None, optional): An asynchronous iterable stream of bytes. Defaults to None. - compression (Compression | None, optional): The compression method to use. Defaults to None. - - """ - super().__init__(codec, read_max_bytes, stream, compression) - self.web = web - self._web_trailers = None - - async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: - """Asynchronously unmarshals a given message and yields each resulting object. - - Args: - message (Any): The message to be unmarshaled. - - Yields: - Any: Each object obtained from unmarshaling the message. - - """ - async for obj, end in super().unmarshal(message): - if end: - env = self.last - if not env: - raise ConnectError("protocol error: empty envelope") - - data = copy(env.data) - env.data = b"" - - if not (self.web and env.is_set(EnvelopeFlags.trailer)): - raise ConnectError( - f"protocol error: invalid envelope flags: {env.flags}", - ) - - trailers = Headers() - lines = data.decode("utf-8").splitlines() - for line in lines: - if line == "": - continue - - name, value = line.split(":", 1) - name = name.strip().lower() - value = value.strip() - if name in trailers: - trailers[name] += "," + value - else: - trailers[name] = value - - self._web_trailers = trailers - - yield obj, end - - @property - def web_trailers(self) -> Headers | None: - """Return the trailers received in the last envelope. - - Returns: - Headers | None: The trailers received in the last envelope, or None if no trailers were received. - - """ - return self._web_trailers - - -EventHook = Callable[..., Any] - - -class GRPCClientConn(StreamingClientConn): - """GRPCClientConn is a gRPC client connection implementation supporting asynchronous streaming requests and responses over HTTP/2. 
- - This class manages the lifecycle of a gRPC client connection, including marshaling and unmarshaling messages, handling request and response headers/trailers, managing compression, and supporting event hooks for request/response events. It integrates with an asynchronous HTTP client session and supports cancellation via asyncio events. - - Attributes: - session (AsyncClientSession): The asynchronous client session used for HTTP requests. - _spec (Spec): The protocol or API specification. - _peer (Peer): Information about the remote peer. - url (URL): The endpoint URL for the connection. - codec (Codec | None): Codec for encoding/decoding messages. - compressions (list[Compression]): Supported compression algorithms. - marshaler (GRPCMarshaler): Marshaler for serializing messages. - unmarshaler (GRPCUnmarshaler): Unmarshaler for deserializing messages. - _response_headers (Headers): HTTP response headers. - _response_trailers (Headers): HTTP response trailers. - _request_headers (Headers): HTTP request headers. - receive_trailers (Callable[[], None] | None): Callback to receive trailers after response. - - """ - - web: bool - session: AsyncClientSession - _spec: Spec - _peer: Peer - url: URL - codec: Codec | None - compressions: list[Compression] - marshaler: GRPCMarshaler - unmarshaler: GRPCUnmarshaler - _response_headers: Headers - _response_trailers: Headers - _request_headers: Headers - receive_trailers: Callable[[], None] | None - - def __init__( - self, - web: bool, - session: AsyncClientSession, - spec: Spec, - peer: Peer, - url: URL, - codec: Codec | None, - compressions: list[Compression], - request_headers: Headers, - marshaler: GRPCMarshaler, - unmarshaler: GRPCUnmarshaler, - event_hooks: None | (Mapping[str, list[EventHook]]) = None, - ) -> None: - """Initialize a new instance of the class. - - Args: - web (bool): Indicates if the connection is for a web environment. - session (AsyncClientSession): The asynchronous client session to use for requests. 
- spec (Spec): The specification object describing the protocol or API. - peer (Peer): The peer information for the connection. - url (URL): The URL endpoint for the connection. - codec (Codec | None): The codec to use for encoding/decoding messages, or None. - compressions (list[Compression]): List of supported compression algorithms. - request_headers (Headers): Headers to include in outgoing requests. - marshaler (GRPCMarshaler): The marshaler for serializing messages. - unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing messages. - event_hooks (None | Mapping[str, list[EventHook]], optional): Optional mapping of event hooks for "request" and "response" events. Defaults to None. - - """ - event_hooks = {} if event_hooks is None else event_hooks - - self.web = web - self.session = session - self._spec = spec - self._peer = peer - self.url = url - self.codec = codec - self.compressions = compressions - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self._response_headers = Headers() - self._response_trailers = Headers() - self._request_headers = request_headers - - self._event_hooks = { - "request": list(event_hooks.get("request", [])), - "response": list(event_hooks.get("response", [])), - } - - @property - def spec(self) -> Spec: - """Return the specification details.""" - return self._spec - - @property - def peer(self) -> Peer: - """Return the peer information.""" - raise NotImplementedError() - - async def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: - """Receives a message and processes it.""" - trailer_received = False - - async for obj, end in self.unmarshaler.unmarshal(message): - if abort_event and abort_event.is_set(): - raise ConnectError("receive operation aborted", Code.CANCELED) - - if end: - if trailer_received: - raise ConnectError("received extra end stream trailer", Code.INVALID_ARGUMENT) - - trailer_received = True - if self.unmarshaler.web_trailers is None: - raise 
ConnectError("trailer not received", Code.INVALID_ARGUMENT) - - continue - - if trailer_received: - raise ConnectError("protocol error: received extra message after trailer", Code.INVALID_ARGUMENT) - - yield obj - - if callable(self.receive_trailers): - self.receive_trailers() - - if self.unmarshaler.bytes_read == 0 and len(self.response_trailers) == 0: - self.response_trailers.update(self._response_headers) - del self._response_headers[HEADER_CONTENT_TYPE] - - server_error = grpc_error_from_trailer(self.response_trailers) - if server_error: - server_error.metadata = self.response_headers.copy() - raise server_error - - server_error = grpc_error_from_trailer(self.response_trailers) - if server_error: - server_error.metadata = self.response_headers.copy() - server_error.metadata.update(self.response_trailers) - raise server_error - - def _receive_trailers(self, response: httpcore.Response) -> None: - if self.web: - trailers = self.unmarshaler.web_trailers - if trailers is not None: - self._response_trailers.update(trailers) - - else: - if "trailing_headers" not in response.extensions: - return - - trailers = response.extensions["trailing_headers"] - self._response_trailers.update(Headers(trailers)) - - @property - def request_headers(self) -> Headers: - """Return the request headers.""" - return self._request_headers - - async def send( - self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None - ) -> None: - """Send a gRPC request asynchronously using HTTP/2 via httpcore, handling streaming messages, timeouts, and abort events. - - Args: - messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent as the request body. - timeout (float | None): Optional timeout in seconds for the request. If provided, sets the gRPC timeout header. - abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request and raise a cancellation error. 
- - Raises: - ConnectError: If the request is aborted before or during execution, or if an error occurs during the HTTP request. - - Side Effects: - - Invokes registered request and response event hooks. - - Sets up the response stream and trailers for further processing. - - Validates the HTTP response. - - """ - extensions = {} - if timeout: - extensions["timeout"] = {"read": timeout} - self._request_headers[GRPC_HEADER_TIMEOUT] = grpc_encode_timeout(timeout) - - content_iterator = self.marshaler.marshal(messages) - - request = httpcore.Request( - method=HTTPMethod.POST, - url=httpcore.URL( - scheme=self.url.scheme, - host=self.url.host or "", - port=self.url.port, - target=self.url.raw_path, - ), - headers=list( - include_request_headers( - headers=self._request_headers, url=self.url, content=content_iterator, method=HTTPMethod.POST - ).items() - ), - content=content_iterator, - extensions=extensions, - ) - - for hook in self._event_hooks["request"]: - hook(request) - - with map_httpcore_exceptions(): - if not abort_event: - response = await self.session.pool.handle_async_request(request) - else: - request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) - abort_task = asyncio.create_task(abort_event.wait()) - - done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) - - if abort_task in done: - request_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await request_task - - raise ConnectError("request aborted", Code.CANCELED) - - abort_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await abort_task - - response = await request_task - - for hook in self._event_hooks["response"]: - hook(response) - - assert isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) - self.receive_trailers = functools.partial(self._receive_trailers, response) - - await 
self._validate_response(response) - - async def _validate_response(self, response: httpcore.Response) -> None: - response_headers = Headers(response.headers) - if response.status != 200: - raise ConnectError( - f"HTTP {response.status}", - code_from_http_status(response.status), - ) - - grpc_validate_response_content_type( - self.web, - self.marshaler.codec.name if self.marshaler.codec else "", - response_headers.get(HEADER_CONTENT_TYPE, ""), - ) - - compression = response_headers.get(GRPC_HEADER_COMPRESSION, None) - if compression and compression != COMPRESSION_IDENTITY: - self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) - - self._response_headers.update(response_headers) - - @property - def response_headers(self) -> Headers: - """Return the response headers.""" - return self._response_headers - - @property - def response_trailers(self) -> Headers: - """Return response trailers.""" - return self._response_trailers - - def on_request_send(self, fn: EventHook) -> None: - """Register a callback function to be invoked when a request is sent. - - Args: - fn (EventHook): The callback function to be added to the "request" event hook. - - Returns: - None - - """ - self._event_hooks["request"].append(fn) - - async def aclose(self) -> None: - """Asynchronously closes the underlying unmarshaler resource. - - This method should be called to properly release any resources held by the unmarshaler, - such as open network connections or file handles, when they are no longer needed. - """ - await self.unmarshaler.aclose() - - -class GRPCHandlerConn(StreamingHandlerConn): - """GRPCHandlerConn is a handler class for managing gRPC protocol connections within a streaming server context. 
- - This class encapsulates the logic for handling gRPC requests and responses, including marshaling and unmarshaling messages, - managing request and response headers/trailers, handling timeouts, and enforcing protocol-specific constraints for unary and streaming operations. - - Attributes: - _spec (Spec): The specification object describing the protocol or service. - _peer (Peer): The peer information for the current connection. - _request_headers (Headers): The headers received with the request. - _response_headers (Headers): The headers to include in the response. - _response_trailers (Headers): The trailers to include in the response. - _is_streaming (bool): Indicates if the connection is streaming. - - """ - - web: bool - _spec: Spec - _peer: Peer - writer: ServerResponseWriter - marshaler: GRPCMarshaler - unmarshaler: GRPCUnmarshaler - _request_headers: Headers - _response_headers: Headers - _response_trailers: Headers - - def __init__( - self, - web: bool, - writer: ServerResponseWriter, - spec: Spec, - peer: Peer, - marshaler: GRPCMarshaler, - unmarshaler: GRPCUnmarshaler, - request_headers: Headers, - response_headers: Headers, - response_trailers: Headers | None = None, - ) -> None: - """Initialize a new instance of the class. - - Args: - web (bool): Indicates if the connection is for a web environment. - writer (ServerResponseWriter): The writer used to send responses to the client. - spec (Spec): The specification object describing the protocol or service. - peer (Peer): The peer information for the current connection. - marshaler (GRPCMarshaler): The marshaler used to serialize response messages. - unmarshaler (GRPCUnmarshaler): The unmarshaler used to deserialize request messages. - request_headers (Headers): The headers received with the request. - response_headers (Headers): The headers to include in the response. - response_trailers (Headers | None, optional): The trailers to include in the response. Defaults to None. 
- is_streaming (bool, optional): Indicates if the connection is streaming. Defaults to False. - - """ - self.web = web - self.writer = writer - self._spec = spec - self._peer = peer - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self._request_headers = request_headers - self._response_headers = response_headers - self._response_trailers = response_trailers if response_trailers is not None else Headers() - - def parse_timeout(self) -> float | None: - """Parse the gRPC timeout value from the request headers and returns it as seconds. - - Returns: - float | None: The timeout value in seconds if present and valid, otherwise None. - - Raises: - ConnectError: If the timeout value is present but invalid or too long. - - Notes: - - The timeout is extracted from the gRPC header and must match the expected format. - - If the timeout unit is hours and exceeds the maximum allowed, None is returned. - - """ - timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) - if not timeout: - return None - - m = _RE.match(timeout) - if m is None: - raise ConnectError(f"protocol error: invalid grpc timeout value: {timeout}") - - num_str, unit = m.groups() - num = int(num_str) - if num > 99_999_999: - raise ConnectError(f"protocol error: timeout {timeout!r} is too long") - - if unit == "H" and num > _MAX_HOURS: - return None - - seconds = num * _UNIT_TO_SECONDS[unit] - return seconds - - @property - def spec(self) -> Spec: - """Returns the specification object associated with this instance. - - Returns: - Spec: The specification object. - - """ - return self._spec - - @property - def peer(self) -> Peer: - """Returns the associated Peer object. - - Returns: - Peer: The peer instance associated with this object. - - """ - return self._peer - - async def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and processes it. - - Args: - message (Any): The message to be received and processed. 
- - Returns: - AsyncIterator[Any]: An async iterator yielding message(s). For non-streaming operations, - this will yield exactly one item. - - """ - async for obj, _ in self.unmarshaler.unmarshal(message): - yield obj - - @property - def request_headers(self) -> Headers: - """Returns the headers associated with the current request. - - Returns: - Headers: The headers of the request. - - """ - return self._request_headers - - async def send(self, messages: AsyncIterable[Any]) -> None: - """Send message(s) by marshaling them into bytes. - - Args: - messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, - this should be an iterable with a single item. - - Returns: - None - - """ - if self.web: - await self.writer.write( - StreamingResponse( - content=self._send_messages(messages), - headers=self.response_headers, - status_code=200, - ) - ) - else: - await self.writer.write( - StreamingResponse( - content=self._send_messages(messages), - headers=self.response_headers, - trailers=self.response_trailers, - status_code=200, - ) - ) - - @property - def response_headers(self) -> Headers: - """Returns the response headers associated with the current request. - - Returns: - Headers: The headers returned in the response. - - """ - return self._response_headers - - @property - def response_trailers(self) -> Headers: - """Returns the response trailers as headers. - - Response trailers are additional metadata sent by the server after the response body, - typically used in gRPC and HTTP/2 protocols. - - Returns: - Headers: The response trailers associated with the current response. - - """ - return self._response_trailers - - async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: - """Asynchronously sends marshaled messages and yields them as byte streams. - - Args: - messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent. 
- - Yields: - bytes: Marshaled message bytes, and optionally marshaled web trailers if in web mode. - - Raises: - ConnectError: If an error occurs during marshaling or sending messages, a ConnectError is set and handled. - - Notes: - - Errors encountered during message marshaling are converted to ConnectError and added to response trailers. - - If running in web mode (`self.web` is True), marshaled web trailers are yielded at the end. - - """ - error: ConnectError | None = None - try: - async for msg in self.marshaler.marshal(messages): - yield msg - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - grpc_error_to_trailer(self.response_trailers, error) - - if self.web: - body = await self.marshaler.marshal_web_trailers(self.response_trailers) - yield body - - async def send_error(self, error: ConnectError) -> None: - """Send an error response over gRPC by converting the provided ConnectError into gRPC trailers. - - Args: - error (ConnectError): The error to be sent as a gRPC trailer. - - Returns: - None - - This method updates the response trailers with the error information and writes a streaming response - with the appropriate headers and trailers to the client. - - """ - grpc_error_to_trailer(self.response_trailers, error) - if self.web: - body = await self.marshaler.marshal_web_trailers(self.response_trailers) - - await self.writer.write( - StreamingResponse( - content=aiterate([body]), - headers=self.response_headers, - status_code=200, - ) - ) - else: - await self.writer.write( - StreamingResponse( - content=[], - headers=self.response_headers, - trailers=self.response_trailers, - status_code=200, - ) - ) - - -def grpc_codec_from_content_type(web: bool, content_type: str) -> str: - """Determine the gRPC codec name from the given content type string. - - Args: - web (bool): Indicates whether the request is a gRPC-web request. - content_type (str): The content type string to parse. 
- - Returns: - str: The codec name extracted from the content type. If the content type matches the default gRPC or gRPC-web content type, - returns the default codec name. Otherwise, extracts and returns the codec name from the content type prefix, or returns - the original content type if no known prefix is found. - - """ - if (not web and content_type == GRPC_CONTENT_TYPE_DEFAULT) or ( - web and content_type == GRPC_WEB_CONTENT_TYPE_DEFAULT - ): - return CodecNameType.PROTO - - prefix = GRPC_CONTENT_TYPE_PREFIX if not web else GRPC_WEB_CONTENT_TYPE_PREFIX - - if content_type.startswith(prefix): - return content_type[len(prefix) :] - else: - return content_type - - -def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: - """Convert a ConnectError to gRPC trailer headers. - - Args: - trailer (Headers): The trailer headers dictionary to update with gRPC error information. - error (ConnectError | None): The error to convert. If None, indicates success. - - Side Effects: - Modifies the `trailer` dictionary in-place to include gRPC status, message, and optional details. - - Notes: - - If `error` is None, sets the gRPC status header to "0" (OK). - - If `ConnectError.wire_error` is False, updates the trailer with error metadata excluding protocol headers. - - Serializes error details using protobuf if present, encoding them in base64 for the trailer. 
- - """ - if error is None: - trailer[GRPC_HEADER_STATUS] = "0" - return - - if not ConnectError.wire_error: - trailer.update(exclude_protocol_headers(error.metadata)) - - status = status_pb2.Status( - code=error.code.value, - message=error.raw_message, - details=error.details_any(), - ) - code = status.code - message = status.message - details_binary = None - - if len(status.details) > 0: - details_binary = status.SerializeToString() - - trailer[GRPC_HEADER_STATUS] = str(code) - trailer[GRPC_HEADER_MESSAGE] = urllib.parse.quote(message) - if details_binary: - trailer[GRPC_HEADER_DETAILS] = base64.b64encode(details_binary).decode().rstrip("=") - - -def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: - """Return the appropriate gRPC content type string based on the given codec name and whether the request is for gRPC-Web. - - Args: - web (bool): Indicates if the content type is for gRPC-Web (True) or standard gRPC (False). - codec_name (str): The name of the codec (e.g., "proto", "json"). - - Returns: - str: The corresponding gRPC content type string. - - """ - if web: - return GRPC_WEB_CONTENT_TYPE_PREFIX + codec_name - - if codec_name == CodecNameType.PROTO: - return GRPC_CONTENT_TYPE_DEFAULT - - return GRPC_CONTENT_TYPE_PREFIX + codec_name - - -def grpc_validate_response_content_type(web: bool, request_codec_name: str, response_content_type: str) -> None: - """Validate that the gRPC response content type matches the expected value based on the request codec and whether gRPC-Web is used. - - Args: - web (bool): Indicates if gRPC-Web is being used. - request_codec_name (str): The name of the codec used in the request (e.g., "proto", "json"). - response_content_type (str): The content type returned in the response. - - Raises: - ConnectError: If the response content type does not match the expected value, with an appropriate error code. 
- - """ - bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX - if web: - bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX - - if response_content_type == prefix + request_codec_name or ( - request_codec_name == CodecNameType.PROTO and response_content_type == bare - ): - return - - expected_content_type = bare - if request_codec_name != CodecNameType.PROTO: - expected_content_type = prefix + request_codec_name - - code = Code.INTERNAL - if response_content_type != bare and not response_content_type.startswith(prefix): - code = Code.UNKNOWN - - raise ConnectError(f"invalid content-type {response_content_type}, expected {expected_content_type}", code) - - -def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: - """Parse gRPC error information from response trailers and constructs a ConnectError if present. - - Args: - trailers (Headers): The gRPC response trailers containing error information. - - Returns: - ConnectError | None: Returns a ConnectError instance if an error is found in the trailers, - or None if the status code indicates success. - - Raises: - ConnectError: If the grpc-status-details-bin trailer or protobuf error details are invalid. - - The function extracts the gRPC status code, error message, and optional error details from the trailers. - If the status code is missing or invalid, it returns a ConnectError with an appropriate message. - If the status code indicates success ("0"), it returns None. - If error details are present and valid, they are attached to the ConnectError. 
- - """ - code_header = trailers.get(GRPC_HEADER_STATUS) - if code_header is None: - code = Code.UNKNOWN - if len(trailers) == 0: - code = Code.INTERNAL - - return ConnectError( - f"protocol error: no {GRPC_HEADER_STATUS} header in trailers", - code, - ) - - if code_header == "0": - return None - - try: - code = Code(int(code_header)) - except ValueError: - return ConnectError( - f"protocol error: invalid error code {code_header} in trailers", - ) - - try: - message = unquote(trailers.get(GRPC_HEADER_MESSAGE, "")) - except Exception: - return ConnectError( - f"protocol error: invalid error message {code_header} in trailers", - code=Code.UNKNOWN, - ) - - ret_error = ConnectError( - message, - code, - wire_error=True, - ) - - details_binary_encoded = trailers.get(GRPC_HEADER_DETAILS, None) - if details_binary_encoded and len(details_binary_encoded) > 0: - try: - details_binary = decode_binary_header(details_binary_encoded) - except Exception as e: - raise ConnectError( - f"server returned invalid grpc-status-details-bin trailer: {e}", - code=Code.INTERNAL, - ) from e - - status = status_pb2.Status() - try: - status.ParseFromString(details_binary) - except DecodeError as e: - raise ConnectError( - f"server returned invalid protobuf for error details: {e}", - code=Code.INTERNAL, - ) from e - - for detail in status.details: - ret_error.details.append(ErrorDetail(pb_any=detail)) - - ret_error.code = Code(status.code) - ret_error.raw_message = status.message - - return ret_error - - -def decode_binary_header(data: str) -> bytes: - """Decode a base64-encoded string representing a binary header. - - If the input string's length is not a multiple of 4, it pads the string with '=' characters - to make it valid base64 before decoding. - - Args: - data (str): The base64-encoded string to decode. - - Returns: - bytes: The decoded binary data. - - Raises: - binascii.Error: If the input is not correctly base64-encoded. 
- - """ - if len(data) % 4: - data += "=" * (-len(data) % 4) - - return base64.b64decode(data, validate=True) - - -def grpc_encode_timeout(timeout: float) -> str: - """Encode a timeout value (in seconds) into the gRPC timeout format string. - - The gRPC timeout format is a decimal number with a time unit suffix, where the unit can be: - - 'H' for hours - - 'M' for minutes - - 'S' for seconds - - 'm' for milliseconds - - 'u' for microseconds - - 'n' for nanoseconds - - If the timeout is less than or equal to zero, returns "0n". - - Args: - timeout (float): The timeout value in seconds. - - Returns: - str: The timeout encoded as a gRPC timeout string. - - """ - if timeout <= 0: - return "0n" - - grpc_timeout_max_value = 10**8 - - _units = dict(sorted(_UNIT_TO_SECONDS.items(), key=lambda item: item[1])) - for unit, size in _units.items(): - if timeout < size * grpc_timeout_max_value: - value = int(timeout / size) - return f"{value}{unit}" - - value = int(timeout / 3600.0) - return f"{value}H" diff --git a/src/connect/protocol_grpc/__init__.py b/src/connect/protocol_grpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/connect/protocol_grpc/constants.py b/src/connect/protocol_grpc/constants.py new file mode 100644 index 0000000..faf48b4 --- /dev/null +++ b/src/connect/protocol_grpc/constants.py @@ -0,0 +1,38 @@ +"""Constants for gRPC protocol implementation in connect-python.""" + +import re +import sys +from http import HTTPMethod + +from connect.version import __version__ + +GRPC_HEADER_COMPRESSION = "Grpc-Encoding" +GRPC_HEADER_ACCEPT_COMPRESSION = "Grpc-Accept-Encoding" +GRPC_HEADER_TIMEOUT = "Grpc-Timeout" +GRPC_HEADER_STATUS = "Grpc-Status" +GRPC_HEADER_MESSAGE = "Grpc-Message" +GRPC_HEADER_DETAILS = "Grpc-Status-Details-Bin" + +GRPC_CONTENT_TYPE_DEFAULT = "application/grpc" +GRPC_WEB_CONTENT_TYPE_DEFAULT = "application/grpc-web" +GRPC_CONTENT_TYPE_PREFIX = GRPC_CONTENT_TYPE_DEFAULT + "+" +GRPC_WEB_CONTENT_TYPE_PREFIX = 
GRPC_WEB_CONTENT_TYPE_DEFAULT + "+" + +HEADER_X_USER_AGENT = "X-User-Agent" + +GRPC_ALLOWED_METHODS = [HTTPMethod.POST] + +DEFAULT_GRPC_USER_AGENT = f"connect-python/{__version__} (Python/{__version__})" + +RE_TIMEOUT = re.compile(r"^(\d{1,8})([HMSmun])$") + +UNIT_TO_SECONDS = { + "n": 1e-9, # nanosecond + "u": 1e-6, # microsecond + "m": 1e-3, # millisecond + "S": 1.0, + "M": 60.0, + "H": 3600.0, +} + +MAX_HOURS = sys.maxsize // (60 * 60 * 1_000_000_000) diff --git a/src/connect/protocol_grpc/content_type.py b/src/connect/protocol_grpc/content_type.py new file mode 100644 index 0000000..33c8289 --- /dev/null +++ b/src/connect/protocol_grpc/content_type.py @@ -0,0 +1,89 @@ +"""Utilities for handling gRPC and gRPC-Web content types and codec validation.""" + +from connect.code import Code +from connect.codec import CodecNameType +from connect.error import ConnectError +from connect.protocol_grpc.constants import ( + GRPC_CONTENT_TYPE_DEFAULT, + GRPC_CONTENT_TYPE_PREFIX, + GRPC_WEB_CONTENT_TYPE_DEFAULT, + GRPC_WEB_CONTENT_TYPE_PREFIX, +) + + +def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: + """Return the appropriate gRPC content type string based on the given codec name and whether the request is for gRPC-Web. + + Args: + web (bool): Indicates if the content type is for gRPC-Web (True) or standard gRPC (False). + codec_name (str): The name of the codec (e.g., "proto", "json"). + + Returns: + str: The corresponding gRPC content type string. + + """ + if web: + return GRPC_WEB_CONTENT_TYPE_PREFIX + codec_name + + if codec_name == CodecNameType.PROTO: + return GRPC_CONTENT_TYPE_DEFAULT + + return GRPC_CONTENT_TYPE_PREFIX + codec_name + + +def grpc_codec_from_content_type(web: bool, content_type: str) -> str: + """Determine the gRPC codec name from the given content type string. + + Args: + web (bool): Indicates whether the request is a gRPC-web request. + content_type (str): The content type string to parse. 
+ + Returns: + str: The codec name extracted from the content type. If the content type matches the default gRPC or gRPC-web content type, + returns the default codec name. Otherwise, extracts and returns the codec name from the content type prefix, or returns + the original content type if no known prefix is found. + + """ + if (not web and content_type == GRPC_CONTENT_TYPE_DEFAULT) or ( + web and content_type == GRPC_WEB_CONTENT_TYPE_DEFAULT + ): + return CodecNameType.PROTO + + prefix = GRPC_CONTENT_TYPE_PREFIX if not web else GRPC_WEB_CONTENT_TYPE_PREFIX + + if content_type.startswith(prefix): + return content_type[len(prefix) :] + else: + return content_type + + +def grpc_validate_response_content_type(web: bool, request_codec_name: str, response_content_type: str) -> None: + """Validate that the gRPC response content type matches the expected value based on the request codec and whether gRPC-Web is used. + + Args: + web (bool): Indicates if gRPC-Web is being used. + request_codec_name (str): The name of the codec used in the request (e.g., "proto", "json"). + response_content_type (str): The content type returned in the response. + + Raises: + ConnectError: If the response content type does not match the expected value, with an appropriate error code. 
+ + """ + bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX + if web: + bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX + + if response_content_type == prefix + request_codec_name or ( + request_codec_name == CodecNameType.PROTO and response_content_type == bare + ): + return + + expected_content_type = bare + if request_codec_name != CodecNameType.PROTO: + expected_content_type = prefix + request_codec_name + + code = Code.INTERNAL + if response_content_type != bare and not response_content_type.startswith(prefix): + code = Code.UNKNOWN + + raise ConnectError(f"invalid content-type {response_content_type}, expected {expected_content_type}", code) diff --git a/src/connect/protocol_grpc/error_trailer.py b/src/connect/protocol_grpc/error_trailer.py new file mode 100644 index 0000000..2319831 --- /dev/null +++ b/src/connect/protocol_grpc/error_trailer.py @@ -0,0 +1,162 @@ +"""Provides functions to convert between ConnectError and gRPC trailer headers.""" + +import base64 +from urllib.parse import quote, unquote + +from google.protobuf.message import DecodeError +from google.rpc import status_pb2 + +from connect.code import Code +from connect.error import ConnectError, ErrorDetail +from connect.headers import Headers +from connect.protocol import exclude_protocol_headers +from connect.protocol_grpc.constants import ( + GRPC_HEADER_DETAILS, + GRPC_HEADER_MESSAGE, + GRPC_HEADER_STATUS, +) + + +def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: + """Convert a ConnectError to gRPC trailer headers. + + Args: + trailer (Headers): The trailer headers dictionary to update with gRPC error information. + error (ConnectError | None): The error to convert. If None, indicates success. + + Side Effects: + Modifies the `trailer` dictionary in-place to include gRPC status, message, and optional details. + + Notes: + - If `error` is None, sets the gRPC status header to "0" (OK). 
+ - If `ConnectError.wire_error` is False, updates the trailer with error metadata excluding protocol headers. + - Serializes error details using protobuf if present, encoding them in base64 for the trailer. + + """ + if error is None: + trailer[GRPC_HEADER_STATUS] = "0" + return + + if not ConnectError.wire_error: + trailer.update(exclude_protocol_headers(error.metadata)) + + status = status_pb2.Status( + code=error.code.value, + message=error.raw_message, + details=error.details_any(), + ) + code = status.code + message = status.message + details_binary = None + + if len(status.details) > 0: + details_binary = status.SerializeToString() + + trailer[GRPC_HEADER_STATUS] = str(code) + trailer[GRPC_HEADER_MESSAGE] = quote(message) + if details_binary: + trailer[GRPC_HEADER_DETAILS] = base64.b64encode(details_binary).decode().rstrip("=") + + +def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: + """Parse gRPC error information from response trailers and constructs a ConnectError if present. + + Args: + trailers (Headers): The gRPC response trailers containing error information. + + Returns: + ConnectError | None: Returns a ConnectError instance if an error is found in the trailers, + or None if the status code indicates success. + + Raises: + ConnectError: If the grpc-status-details-bin trailer or protobuf error details are invalid. + + The function extracts the gRPC status code, error message, and optional error details from the trailers. + If the status code is missing or invalid, it returns a ConnectError with an appropriate message. + If the status code indicates success ("0"), it returns None. + If error details are present and valid, they are attached to the ConnectError. 
+ + """ + code_header = trailers.get(GRPC_HEADER_STATUS) + if code_header is None: + code = Code.UNKNOWN + if len(trailers) == 0: + code = Code.INTERNAL + + return ConnectError( + f"protocol error: no {GRPC_HEADER_STATUS} header in trailers", + code, + ) + + if code_header == "0": + return None + + try: + code = Code(int(code_header)) + except ValueError: + return ConnectError( + f"protocol error: invalid error code {code_header} in trailers", + ) + + try: + message = unquote(trailers.get(GRPC_HEADER_MESSAGE, "")) + except Exception: + return ConnectError( + f"protocol error: invalid error message {code_header} in trailers", + code=Code.UNKNOWN, + ) + + ret_error = ConnectError( + message, + code, + wire_error=True, + ) + + details_binary_encoded = trailers.get(GRPC_HEADER_DETAILS, None) + if details_binary_encoded and len(details_binary_encoded) > 0: + try: + details_binary = decode_binary_header(details_binary_encoded) + except Exception as e: + raise ConnectError( + f"server returned invalid grpc-status-details-bin trailer: {e}", + code=Code.INTERNAL, + ) from e + + status = status_pb2.Status() + try: + status.ParseFromString(details_binary) + except DecodeError as e: + raise ConnectError( + f"server returned invalid protobuf for error details: {e}", + code=Code.INTERNAL, + ) from e + + for detail in status.details: + ret_error.details.append(ErrorDetail(pb_any=detail)) + + ret_error.code = Code(status.code) + ret_error.raw_message = status.message + + return ret_error + + +def decode_binary_header(data: str) -> bytes: + """Decode a base64-encoded string representing a binary header. + + If the input string's length is not a multiple of 4, it pads the string with '=' characters + to make it valid base64 before decoding. + + Args: + data (str): The base64-encoded string to decode. + + Returns: + bytes: The decoded binary data. + + Raises: + binascii.Error: If the input is not correctly base64-encoded. 
+ + """ + if len(data) % 4: + data += "=" * (-len(data) % 4) + + return base64.b64decode(data, validate=True) diff --git a/src/connect/protocol_grpc/grpc_client.py b/src/connect/protocol_grpc/grpc_client.py new file mode 100644 index 0000000..58ace49 --- /dev/null +++ b/src/connect/protocol_grpc/grpc_client.py @@ -0,0 +1,478 @@ +"""gRPC client implementation for Connect-Python, supporting async streaming and HTTP/2 communication.""" + +import asyncio +import contextlib +import functools +from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping +from http import HTTPMethod +from typing import Any + +import httpcore +from yarl import URL + +from connect.byte_stream import HTTPCoreResponseAsyncByteStream +from connect.code import Code +from connect.codec import Codec +from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name +from connect.connect import ( + Peer, + Spec, + StreamingClientConn, + StreamType, +) +from connect.error import ConnectError +from connect.headers import Headers, include_request_headers +from connect.protocol import ( + HEADER_CONTENT_TYPE, + HEADER_USER_AGENT, + ProtocolClient, + ProtocolClientParams, + code_from_http_status, +) +from connect.protocol_grpc.constants import ( + DEFAULT_GRPC_USER_AGENT, + GRPC_HEADER_ACCEPT_COMPRESSION, + GRPC_HEADER_COMPRESSION, + GRPC_HEADER_TIMEOUT, + HEADER_X_USER_AGENT, + UNIT_TO_SECONDS, +) +from connect.protocol_grpc.content_type import grpc_content_type_from_codec_name, grpc_validate_response_content_type +from connect.protocol_grpc.error_trailer import grpc_error_from_trailer +from connect.protocol_grpc.marshaler import GRPCMarshaler +from connect.protocol_grpc.unmarshaler import GRPCUnmarshaler +from connect.session import AsyncClientSession +from connect.utils import map_httpcore_exceptions + +EventHook = Callable[..., Any] + + +class GRPCClient(ProtocolClient): + """GRPCClient is a protocol client implementation for gRPC communication, supporting 
both standard and web environments. + + Attributes: + params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, session, and URL. + _peer (Peer): The peer instance associated with this client, representing the remote endpoint. + web (bool): Indicates whether the client is running in a web environment, affecting header and content-type handling. + + """ + + params: ProtocolClientParams + _peer: Peer + web: bool + + def __init__(self, params: ProtocolClientParams, peer: Peer, web: bool) -> None: + """Initialize the ProtocolClient with the given parameters. + + Args: + params (ProtocolClientParams): The parameters for the protocol client. + peer (Peer): The peer instance to be used. + web (bool): Indicates whether the client is running in a web environment. + + """ + self.params = params + self._peer = peer + self.web = web + + @property + def peer(self) -> Peer: + """Returns the associated Peer object. + + Returns: + Peer: The peer instance associated with this object. + + """ + return self._peer + + def write_request_headers(self, _: StreamType, headers: Headers) -> None: + """Set and modifies HTTP/2 or gRPC request headers based on the stream type, connection parameters, and environment. + + Args: + stream_type (StreamType): The type of stream for which headers are being written. + headers (Headers): The dictionary of headers to be modified or populated. + + Behavior: + - Ensures the 'User-Agent' header is set to the default gRPC user agent if not already present. + - If running in a web environment, also sets the 'X-User-Agent' header. + - Sets the 'Content-Type' header according to the codec name and environment. + - Sets the 'Accept-Encoding' header to indicate supported compression. + - If a specific compression is configured and is not the identity, sets the gRPC compression header. + - If multiple compressions are supported, sets the gRPC accept compression header with the supported values. 
+ - For non-web environments, adds the 'Te: trailers' header required for gRPC. + + Note: + This method mutates the provided headers dictionary in place. + + """ + if headers.get(HEADER_USER_AGENT, None) is None: + headers[HEADER_USER_AGENT] = DEFAULT_GRPC_USER_AGENT + + if self.web and headers.get(HEADER_X_USER_AGENT, None) is None: + headers[HEADER_X_USER_AGENT] = DEFAULT_GRPC_USER_AGENT + + headers[HEADER_CONTENT_TYPE] = grpc_content_type_from_codec_name(self.web, self.params.codec.name) + + headers["Accept-Encoding"] = COMPRESSION_IDENTITY + if self.params.compression_name and self.params.compression_name != COMPRESSION_IDENTITY: + headers[GRPC_HEADER_COMPRESSION] = self.params.compression_name + + if self.params.compressions: + headers[GRPC_HEADER_ACCEPT_COMPRESSION] = ", ".join(c.name for c in self.params.compressions) + + if not self.web: + headers["Te"] = "trailers" + + def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: + """Create and returns a GRPCClientConn instance configured with the provided specification and headers. + + Args: + spec (Spec): The specification object defining the protocol or service interface. + headers (Headers): The request headers to include in the connection. + + Returns: + StreamingClientConn: An initialized gRPC streaming client connection. + + Details: + - Configures the connection with parameters such as session, peer, URL, codec, and compression settings. + - Initializes GRPCMarshaler and GRPCUnmarshaler with appropriate codecs and limits. + - Compression is determined using the provided compression name and available compressions. 
+ + """ + return GRPCClientConn( + web=self.web, + session=self.params.session, + spec=spec, + peer=self.peer, + url=self.params.url, + codec=self.params.codec, + compressions=self.params.compressions, + marshaler=GRPCMarshaler( + codec=self.params.codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), + ), + unmarshaler=GRPCUnmarshaler( + web=self.web, + codec=self.params.codec, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=headers, + ) + + +class GRPCClientConn(StreamingClientConn): + """GRPCClientConn is a gRPC client connection implementation supporting asynchronous streaming requests and responses over HTTP/2. + + This class manages the lifecycle of a gRPC client connection, including marshaling and unmarshaling messages, handling request and response headers/trailers, managing compression, and supporting event hooks for request/response events. It integrates with an asynchronous HTTP client session and supports cancellation via asyncio events. + + Attributes: + session (AsyncClientSession): The asynchronous client session used for HTTP requests. + _spec (Spec): The protocol or API specification. + _peer (Peer): Information about the remote peer. + url (URL): The endpoint URL for the connection. + codec (Codec | None): Codec for encoding/decoding messages. + compressions (list[Compression]): Supported compression algorithms. + marshaler (GRPCMarshaler): Marshaler for serializing messages. + unmarshaler (GRPCUnmarshaler): Unmarshaler for deserializing messages. + _response_headers (Headers): HTTP response headers. + _response_trailers (Headers): HTTP response trailers. + _request_headers (Headers): HTTP request headers. + receive_trailers (Callable[[], None] | None): Callback to receive trailers after response. 
+ + """ + + web: bool + session: AsyncClientSession + _spec: Spec + _peer: Peer + url: URL + codec: Codec | None + compressions: list[Compression] + marshaler: GRPCMarshaler + unmarshaler: GRPCUnmarshaler + _response_headers: Headers + _response_trailers: Headers + _request_headers: Headers + receive_trailers: Callable[[], None] | None + + def __init__( + self, + web: bool, + session: AsyncClientSession, + spec: Spec, + peer: Peer, + url: URL, + codec: Codec | None, + compressions: list[Compression], + request_headers: Headers, + marshaler: GRPCMarshaler, + unmarshaler: GRPCUnmarshaler, + event_hooks: None | (Mapping[str, list[EventHook]]) = None, + ) -> None: + """Initialize a new instance of the class. + + Args: + web (bool): Indicates if the connection is for a web environment. + session (AsyncClientSession): The asynchronous client session to use for requests. + spec (Spec): The specification object describing the protocol or API. + peer (Peer): The peer information for the connection. + url (URL): The URL endpoint for the connection. + codec (Codec | None): The codec to use for encoding/decoding messages, or None. + compressions (list[Compression]): List of supported compression algorithms. + request_headers (Headers): Headers to include in outgoing requests. + marshaler (GRPCMarshaler): The marshaler for serializing messages. + unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing messages. + event_hooks (None | Mapping[str, list[EventHook]], optional): Optional mapping of event hooks for "request" and "response" events. Defaults to None. 
+ + """ + event_hooks = {} if event_hooks is None else event_hooks + + self.web = web + self.session = session + self._spec = spec + self._peer = peer + self.url = url + self.codec = codec + self.compressions = compressions + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self._response_headers = Headers() + self._response_trailers = Headers() + self._request_headers = request_headers + + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def spec(self) -> Spec: + """Return the specification details.""" + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer information.""" + raise NotImplementedError() + + async def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: + """Receives a message and processes it.""" + trailer_received = False + + async for obj, end in self.unmarshaler.unmarshal(message): + if abort_event and abort_event.is_set(): + raise ConnectError("receive operation aborted", Code.CANCELED) + + if end: + if trailer_received: + raise ConnectError("received extra end stream trailer", Code.INVALID_ARGUMENT) + + trailer_received = True + if self.unmarshaler.web_trailers is None: + raise ConnectError("trailer not received", Code.INVALID_ARGUMENT) + + continue + + if trailer_received: + raise ConnectError("protocol error: received extra message after trailer", Code.INVALID_ARGUMENT) + + yield obj + + if callable(self.receive_trailers): + self.receive_trailers() + + if self.unmarshaler.bytes_read == 0 and len(self.response_trailers) == 0: + self.response_trailers.update(self._response_headers) + del self._response_headers[HEADER_CONTENT_TYPE] + + server_error = grpc_error_from_trailer(self.response_trailers) + if server_error: + server_error.metadata = self.response_headers.copy() + raise server_error + + server_error = grpc_error_from_trailer(self.response_trailers) + if server_error: + 
server_error.metadata = self.response_headers.copy() + server_error.metadata.update(self.response_trailers) + raise server_error + + def _receive_trailers(self, response: httpcore.Response) -> None: + if self.web: + trailers = self.unmarshaler.web_trailers + if trailers is not None: + self._response_trailers.update(trailers) + + else: + if "trailing_headers" not in response.extensions: + return + + trailers = response.extensions["trailing_headers"] + self._response_trailers.update(Headers(trailers)) + + @property + def request_headers(self) -> Headers: + """Return the request headers.""" + return self._request_headers + + async def send( + self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None + ) -> None: + """Send a gRPC request asynchronously using HTTP/2 via httpcore, handling streaming messages, timeouts, and abort events. + + Args: + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent as the request body. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets the gRPC timeout header. + abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request and raise a cancellation error. + + Raises: + ConnectError: If the request is aborted before or during execution, or if an error occurs during the HTTP request. + + Side Effects: + - Invokes registered request and response event hooks. + - Sets up the response stream and trailers for further processing. + - Validates the HTTP response. 
+ + """ + extensions = {} + if timeout: + extensions["timeout"] = {"read": timeout} + self._request_headers[GRPC_HEADER_TIMEOUT] = grpc_encode_timeout(timeout) + + content_iterator = self.marshaler.marshal(messages) + + request = httpcore.Request( + method=HTTPMethod.POST, + url=httpcore.URL( + scheme=self.url.scheme, + host=self.url.host or "", + port=self.url.port, + target=self.url.raw_path, + ), + headers=list( + include_request_headers( + headers=self._request_headers, url=self.url, content=content_iterator, method=HTTPMethod.POST + ).items() + ), + content=content_iterator, + extensions=extensions, + ) + + for hook in self._event_hooks["request"]: + hook(request) + + with map_httpcore_exceptions(): + if not abort_event: + response = await self.session.pool.handle_async_request(request) + else: + request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + abort_task = asyncio.create_task(abort_event.wait()) + + done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) + + if abort_task in done: + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task + + raise ConnectError("request aborted", Code.CANCELED) + + abort_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await abort_task + + response = await request_task + + for hook in self._event_hooks["response"]: + hook(response) + + assert isinstance(response.stream, AsyncIterable) + self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) + self.receive_trailers = functools.partial(self._receive_trailers, response) + + await self._validate_response(response) + + async def _validate_response(self, response: httpcore.Response) -> None: + response_headers = Headers(response.headers) + if response.status != 200: + raise ConnectError( + f"HTTP {response.status}", + code_from_http_status(response.status), + ) + + grpc_validate_response_content_type( + self.web, + 
self.marshaler.codec.name if self.marshaler.codec else "", + response_headers.get(HEADER_CONTENT_TYPE, ""), + ) + + compression = response_headers.get(GRPC_HEADER_COMPRESSION, None) + if compression and compression != COMPRESSION_IDENTITY: + self.unmarshaler.compression = get_compresion_from_name(compression, self.compressions) + + self._response_headers.update(response_headers) + + @property + def response_headers(self) -> Headers: + """Return the response headers.""" + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Return response trailers.""" + return self._response_trailers + + def on_request_send(self, fn: EventHook) -> None: + """Register a callback function to be invoked when a request is sent. + + Args: + fn (EventHook): The callback function to be added to the "request" event hook. + + Returns: + None + + """ + self._event_hooks["request"].append(fn) + + async def aclose(self) -> None: + """Asynchronously closes the underlying unmarshaler resource. + + This method should be called to properly release any resources held by the unmarshaler, + such as open network connections or file handles, when they are no longer needed. + """ + await self.unmarshaler.aclose() + + +def grpc_encode_timeout(timeout: float) -> str: + """Encode a timeout value (in seconds) into the gRPC timeout format string. + + The gRPC timeout format is a decimal number with a time unit suffix, where the unit can be: + - 'H' for hours + - 'M' for minutes + - 'S' for seconds + - 'm' for milliseconds + - 'u' for microseconds + - 'n' for nanoseconds + + If the timeout is less than or equal to zero, returns "0n". + + Args: + timeout (float): The timeout value in seconds. + + Returns: + str: The timeout encoded as a gRPC timeout string. 
+ + """ + if timeout <= 0: + return "0n" + + grpc_timeout_max_value = 10**8 + + _units = dict(sorted(UNIT_TO_SECONDS.items(), key=lambda item: item[1])) + for unit, size in _units.items(): + if timeout < size * grpc_timeout_max_value: + value = int(timeout / size) + return f"{value}{unit}" + + value = int(timeout / 3600.0) + return f"{value}H" diff --git a/src/connect/protocol_grpc/grpc_handler.py b/src/connect/protocol_grpc/grpc_handler.py new file mode 100644 index 0000000..98d573d --- /dev/null +++ b/src/connect/protocol_grpc/grpc_handler.py @@ -0,0 +1,437 @@ +"""gRPC and gRPC-Web protocol handler implementation for Connect Python server.""" + +from collections.abc import AsyncIterable, AsyncIterator +from http import HTTPMethod +from typing import Any + +from connect.code import Code +from connect.compression import COMPRESSION_IDENTITY +from connect.connect import ( + Address, + Peer, + Spec, + StreamingHandlerConn, +) +from connect.error import ConnectError +from connect.headers import Headers +from connect.protocol import ( + HEADER_CONTENT_TYPE, + PROTOCOL_GRPC, + ProtocolHandler, + ProtocolHandlerParams, + negotiate_compression, +) +from connect.protocol_grpc.constants import ( + GRPC_ALLOWED_METHODS, + GRPC_HEADER_ACCEPT_COMPRESSION, + GRPC_HEADER_COMPRESSION, + GRPC_HEADER_TIMEOUT, + MAX_HOURS, + RE_TIMEOUT, + UNIT_TO_SECONDS, +) +from connect.protocol_grpc.content_type import grpc_codec_from_content_type +from connect.protocol_grpc.error_trailer import grpc_error_to_trailer +from connect.protocol_grpc.marshaler import GRPCMarshaler +from connect.protocol_grpc.unmarshaler import GRPCUnmarshaler +from connect.request import Request +from connect.streaming_response import StreamingResponse +from connect.utils import aiterate +from connect.writer import ServerResponseWriter + + +class GRPCHandler(ProtocolHandler): + """GRPCHandler is a protocol handler for gRPC and gRPC-Web requests. 
+ + This class implements the ProtocolHandler interface to handle gRPC protocol requests, + including negotiation of compression, codec selection, and connection management for + both standard gRPC and gRPC-Web. It supports content type negotiation, payload handling, + and manages the lifecycle of a gRPC connection, including streaming and non-streaming + requests. + + Attributes: + params (ProtocolHandlerParams): Configuration parameters for the handler, including codecs and compressions. + web (bool): Indicates if the handler is for gRPC-Web. + accept (list[str]): List of accepted content types. + + """ + + params: ProtocolHandlerParams + web: bool + accept: list[str] + + def __init__(self, params: ProtocolHandlerParams, web: bool, accept: list[str]) -> None: + """Initialize the ProtocolHandler with the given parameters. + + Args: + params (ProtocolHandlerParams): The parameters required for the protocol handler. + web (bool): Indicates whether the handler is for web usage. + accept (list[str]): A list of accepted content types. + + Returns: + None + + """ + self.params = params + self.web = web + self.accept = accept + + @property + def methods(self) -> list[HTTPMethod]: + """Returns a list of allowed HTTP methods for gRPC protocol. + + Returns: + list[HTTPMethod]: A list containing the HTTP methods permitted for gRPC communication. + + """ + return GRPC_ALLOWED_METHODS + + def content_types(self) -> list[str]: + """Return a list of accepted content types. + + Returns: + list[str]: A list of MIME types that are accepted. + + """ + return self.accept + + def can_handle_payload(self, _: Request, content_type: str) -> bool: + """Determine if the given content type is supported by this handler. + + Args: + _ (Request): The request object (unused). + content_type (str): The MIME type of the payload to check. + + Returns: + bool: True if the content type is accepted, False otherwise. 
+ + """ + return content_type in self.accept + + async def conn( + self, + request: Request, + response_headers: Headers, + response_trailers: Headers, + writer: ServerResponseWriter, + ) -> StreamingHandlerConn | None: + """Handle a connection request. + + Args: + request (Request): The incoming request object. + response_headers (Headers): The headers to be sent in the response. + response_trailers (Headers): The trailers to be sent in the response. + writer (ServerResponseWriter): The writer used to send the response. + is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. + + Returns: + StreamingHandlerConn | None: The connection handler or None if not implemented. + + """ + content_encoding = request.headers.get(GRPC_HEADER_COMPRESSION) + accept_encoding = request.headers.get(GRPC_HEADER_ACCEPT_COMPRESSION) + + request_compression, response_compression, error = negotiate_compression( + self.params.compressions, content_encoding, accept_encoding + ) + + response_headers[HEADER_CONTENT_TYPE] = request.headers.get(HEADER_CONTENT_TYPE, "") + response_headers[GRPC_HEADER_ACCEPT_COMPRESSION] = f"{', '.join(c.name for c in self.params.compressions)}" + if response_compression and response_compression.name != COMPRESSION_IDENTITY: + response_headers[GRPC_HEADER_COMPRESSION] = response_compression.name + + codec_name = grpc_codec_from_content_type(self.web, request.headers.get(HEADER_CONTENT_TYPE, "")) + codec = self.params.codecs.get(codec_name) + protocol_name = PROTOCOL_GRPC if not self.web else PROTOCOL_GRPC + "-web" + + peer = Peer( + address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, + protocol=protocol_name, + query=request.query_params, + ) + + conn = GRPCHandlerConn( + web=self.web, + writer=writer, + spec=self.params.spec, + peer=peer, + marshaler=GRPCMarshaler( + codec, + response_compression, + self.params.compress_min_bytes, + self.params.send_max_bytes, + ), + 
unmarshaler=GRPCUnmarshaler(
+                self.web,
+                codec,
+                self.params.read_max_bytes,
+                request.stream(),
+                request_compression,
+            ),
+            request_headers=Headers(request.headers, encoding="latin-1"),
+            response_headers=response_headers,
+            response_trailers=response_trailers,
+        )
+
+        if error:
+            await conn.send_error(error)
+            return None
+
+        return conn
+
+
+class GRPCHandlerConn(StreamingHandlerConn):
+    """GRPCHandlerConn is a handler class for managing gRPC protocol connections within a streaming server context.
+
+    This class encapsulates the logic for handling gRPC requests and responses, including marshaling and unmarshaling messages,
+    managing request and response headers/trailers, handling timeouts, and enforcing protocol-specific constraints for unary and streaming operations.
+
+    Attributes:
+        web (bool): Whether the connection speaks gRPC-Web rather than plain gRPC.
+        _spec (Spec): The specification object describing the protocol or service.
+        _peer (Peer): The peer information for the current connection.
+        _request_headers (Headers): The headers received with the request.
+        _response_headers (Headers): The headers to include in the response.
+        _response_trailers (Headers): The trailers to include in the response.
+
+    """
+
+    web: bool
+    _spec: Spec
+    _peer: Peer
+    writer: ServerResponseWriter
+    marshaler: GRPCMarshaler
+    unmarshaler: GRPCUnmarshaler
+    _request_headers: Headers
+    _response_headers: Headers
+    _response_trailers: Headers
+
+    def __init__(
+        self,
+        web: bool,
+        writer: ServerResponseWriter,
+        spec: Spec,
+        peer: Peer,
+        marshaler: GRPCMarshaler,
+        unmarshaler: GRPCUnmarshaler,
+        request_headers: Headers,
+        response_headers: Headers,
+        response_trailers: Headers | None = None,
+    ) -> None:
+        """Initialize a new instance of the class.
+
+        Args:
+            web (bool): Indicates if the connection is for a web environment.
+            writer (ServerResponseWriter): The writer used to send responses to the client.
+            spec (Spec): The specification object describing the protocol or service.
+            peer (Peer): The peer information for the current connection.
+            marshaler (GRPCMarshaler): The marshaler used to serialize response messages.
+            unmarshaler (GRPCUnmarshaler): The unmarshaler used to deserialize request messages.
+            request_headers (Headers): The headers received with the request.
+            response_headers (Headers): The headers to include in the response.
+            response_trailers (Headers | None, optional): The trailers to include in the response. Defaults to None.
+
+        """
+        self.web = web
+        self.writer = writer
+        self._spec = spec
+        self._peer = peer
+        self.marshaler = marshaler
+        self.unmarshaler = unmarshaler
+        self._request_headers = request_headers
+        self._response_headers = response_headers
+        self._response_trailers = response_trailers if response_trailers is not None else Headers()
+
+    def parse_timeout(self) -> float | None:
+        """Parse the gRPC timeout value from the request headers and return it in seconds.
+
+        Returns:
+            float | None: The timeout value in seconds if present and valid, otherwise None.
+
+        Raises:
+            ConnectError: If the timeout value is present but invalid or too long.
+
+        Notes:
+            - The timeout is extracted from the gRPC header and must match the expected format.
+            - If the timeout unit is hours and exceeds the maximum allowed, None is returned.
+ + """ + timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) + if not timeout: + return None + + m = RE_TIMEOUT.match(timeout) + if m is None: + raise ConnectError(f"protocol error: invalid grpc timeout value: {timeout}") + + num_str, unit = m.groups() + num = int(num_str) + if num > 99_999_999: + raise ConnectError(f"protocol error: timeout {timeout!r} is too long") + + if unit == "H" and num > MAX_HOURS: + return None + + seconds = num * UNIT_TO_SECONDS[unit] + return seconds + + @property + def spec(self) -> Spec: + """Returns the specification object associated with this instance. + + Returns: + Spec: The specification object. + + """ + return self._spec + + @property + def peer(self) -> Peer: + """Returns the associated Peer object. + + Returns: + Peer: The peer instance associated with this object. + + """ + return self._peer + + async def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message and processes it. + + Args: + message (Any): The message to be received and processed. + + Returns: + AsyncIterator[Any]: An async iterator yielding message(s). For non-streaming operations, + this will yield exactly one item. + + """ + async for obj, _ in self.unmarshaler.unmarshal(message): + yield obj + + @property + def request_headers(self) -> Headers: + """Returns the headers associated with the current request. + + Returns: + Headers: The headers of the request. + + """ + return self._request_headers + + async def send(self, messages: AsyncIterable[Any]) -> None: + """Send message(s) by marshaling them into bytes. + + Args: + messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, + this should be an iterable with a single item. 
+ + Returns: + None + + """ + if self.web: + await self.writer.write( + StreamingResponse( + content=self._send_messages(messages), + headers=self.response_headers, + status_code=200, + ) + ) + else: + await self.writer.write( + StreamingResponse( + content=self._send_messages(messages), + headers=self.response_headers, + trailers=self.response_trailers, + status_code=200, + ) + ) + + @property + def response_headers(self) -> Headers: + """Returns the response headers associated with the current request. + + Returns: + Headers: The headers returned in the response. + + """ + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Returns the response trailers as headers. + + Response trailers are additional metadata sent by the server after the response body, + typically used in gRPC and HTTP/2 protocols. + + Returns: + Headers: The response trailers associated with the current response. + + """ + return self._response_trailers + + async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + """Asynchronously sends marshaled messages and yields them as byte streams. + + Args: + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent. + + Yields: + bytes: Marshaled message bytes, and optionally marshaled web trailers if in web mode. + + Raises: + ConnectError: If an error occurs during marshaling or sending messages, a ConnectError is set and handled. + + Notes: + - Errors encountered during message marshaling are converted to ConnectError and added to response trailers. + - If running in web mode (`self.web` is True), marshaled web trailers are yielded at the end. 
+ + """ + error: ConnectError | None = None + try: + async for msg in self.marshaler.marshal(messages): + yield msg + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + finally: + grpc_error_to_trailer(self.response_trailers, error) + + if self.web: + body = await self.marshaler.marshal_web_trailers(self.response_trailers) + yield body + + async def send_error(self, error: ConnectError) -> None: + """Send an error response over gRPC by converting the provided ConnectError into gRPC trailers. + + Args: + error (ConnectError): The error to be sent as a gRPC trailer. + + Returns: + None + + This method updates the response trailers with the error information and writes a streaming response + with the appropriate headers and trailers to the client. + + """ + grpc_error_to_trailer(self.response_trailers, error) + if self.web: + body = await self.marshaler.marshal_web_trailers(self.response_trailers) + + await self.writer.write( + StreamingResponse( + content=aiterate([body]), + headers=self.response_headers, + status_code=200, + ) + ) + else: + await self.writer.write( + StreamingResponse( + content=[], + headers=self.response_headers, + trailers=self.response_trailers, + status_code=200, + ) + ) diff --git a/src/connect/protocol_grpc/grpc_protocol.py b/src/connect/protocol_grpc/grpc_protocol.py new file mode 100644 index 0000000..56eff49 --- /dev/null +++ b/src/connect/protocol_grpc/grpc_protocol.py @@ -0,0 +1,91 @@ +"""Protocol implementation for handling gRPC and gRPC-Web requests.""" + +from connect.codec import CodecNameType +from connect.connect import ( + Address, + Peer, +) +from connect.protocol import ( + PROTOCOL_GRPC, + PROTOCOL_GRPC_WEB, + Protocol, + ProtocolClient, + ProtocolClientParams, + ProtocolHandler, + ProtocolHandlerParams, +) +from connect.protocol_grpc.constants import ( + GRPC_CONTENT_TYPE_DEFAULT, + GRPC_CONTENT_TYPE_PREFIX, + GRPC_WEB_CONTENT_TYPE_DEFAULT, + 
GRPC_WEB_CONTENT_TYPE_PREFIX, +) +from connect.protocol_grpc.grpc_client import GRPCClient +from connect.protocol_grpc.grpc_handler import GRPCHandler + + +class ProtocolGRPC(Protocol): + """ProtocolGRPC is a protocol implementation for handling gRPC and gRPC-Web requests. + + Attributes: + web (bool): Indicates whether to use gRPC-Web (True) or standard gRPC (False). + + """ + + def __init__(self, web: bool) -> None: + """Initialize the instance. + + Args: + web (bool): Indicates whether the instance is for web usage. + + """ + self.web = web + + def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler: + """Create and returns a GRPCHandler instance configured with appropriate content types based on the provided parameters. + + Args: + params (ProtocolHandlerParams): The parameters containing codec information and other handler configuration. + + Returns: + ProtocolHandler: An instance of GRPCHandler initialized with the correct content types for gRPC or gRPC-Web. + + Behavior: + - Determines the default and prefix content types based on whether gRPC-Web is enabled. + - Constructs a list of supported content types from the available codecs. + - Adds the bare content type if the PROTO codec is present. + - Returns a GRPCHandler with the computed content types. + + """ + bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX + if self.web: + bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX + + content_types: list[str] = [] + for name in params.codecs.names(): + content_types.append(prefix + name) + + if params.codecs.get(CodecNameType.PROTO): + content_types.append(bare) + + return GRPCHandler(params, self.web, content_types) + + def client(self, params: ProtocolClientParams) -> ProtocolClient: + """Create and return a GRPCClient instance. + + Args: + params (ProtocolClientParams): The parameters required to initialize the client. + + Returns: + ProtocolClient: An instance of GRPCClient. 
+ + """ + peer = Peer( + address=Address(host=params.url.host or "", port=params.url.port or 80), + protocol=PROTOCOL_GRPC, + query={}, + ) + if self.web: + peer.protocol = PROTOCOL_GRPC_WEB + + return GRPCClient(params, peer, self.web) diff --git a/src/connect/protocol_grpc/marshaler.py b/src/connect/protocol_grpc/marshaler.py new file mode 100644 index 0000000..3eb9740 --- /dev/null +++ b/src/connect/protocol_grpc/marshaler.py @@ -0,0 +1,62 @@ +"""Marshaler for encoding messages into the gRPC wire format, including gRPC-Web trailer support.""" + +from connect.codec import Codec +from connect.compression import Compression +from connect.envelope import EnvelopeFlags, EnvelopeWriter +from connect.headers import Headers + + +class GRPCMarshaler(EnvelopeWriter): + """GRPCMarshaler is responsible for marshaling messages into the gRPC wire format. + + Args: + codec (Codec | None): The codec used for encoding/decoding messages. + compression (Compression | None): The compression algorithm to use, if any. + compress_min_bytes (int): Minimum message size in bytes before compression is applied. + send_max_bytes (int): Maximum allowed size of a message to send. + + Methods: + marshal(messages: AsyncIterable[bytes]) -> AsyncIterator[bytes]: + Asynchronously marshals a stream of message bytes into the gRPC wire format. + Yields marshaled message bytes ready for transmission. + + """ + + def __init__( + self, + codec: Codec | None, + compression: Compression | None, + compress_min_bytes: int, + send_max_bytes: int, + ) -> None: + """Initialize the protocol with the specified configuration. + + Args: + codec (Codec | None): The codec to use for encoding/decoding messages, or None for default. + compression (Compression | None): The compression algorithm to use, or None for no compression. + compress_min_bytes (int): The minimum number of bytes before compression is applied. + send_max_bytes (int): The maximum number of bytes allowed to send in a single message. 
+ + Returns: + None + + """ + super().__init__(codec, compression, compress_min_bytes, send_max_bytes) + + async def marshal_web_trailers(self, trailers: Headers) -> bytes: + """Serialize HTTP trailer headers into a gRPC-Web trailer envelope. + + Args: + trailers (Headers): A dictionary-like object containing HTTP trailer headers. + + Returns: + bytes: The serialized gRPC-Web trailer envelope containing the trailer headers. + + """ + lines = [] + for key, value in trailers.items(): + lines.append(f"{key}: {value}\r\n") + + env = self.write_envelope("".join(lines).encode(), EnvelopeFlags.trailer) + + return env.encode() diff --git a/src/connect/protocol_grpc/unmarshaler.py b/src/connect/protocol_grpc/unmarshaler.py new file mode 100644 index 0000000..222a02d --- /dev/null +++ b/src/connect/protocol_grpc/unmarshaler.py @@ -0,0 +1,105 @@ +"""Module for gRPC message unmarshaling using EnvelopeReader and related utilities.""" + +from collections.abc import AsyncIterable, AsyncIterator +from copy import copy +from typing import Any + +from connect.codec import Codec +from connect.compression import Compression +from connect.envelope import EnvelopeFlags, EnvelopeReader +from connect.error import ConnectError +from connect.headers import Headers + + +class GRPCUnmarshaler(EnvelopeReader): + """GRPCUnmarshaler is a specialized EnvelopeReader for handling gRPC message unmarshaling. + + Args: + codec (Codec | None): The codec used for decoding messages. + read_max_bytes (int): The maximum number of bytes to read from the stream. + stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read messages from. + compression (Compression | None, optional): Compression algorithm to use for decompressing messages. + + Methods: + async unmarshal(message: Any) -> AsyncIterator[Any]: + Asynchronously unmarshals the given message, yielding each decoded object. 
+            Iterates over the results of the parent EnvelopeReader, yielding each (object, end-of-stream flag) tuple.
+
+    """
+
+    web: bool
+    _web_trailers: Headers | None
+
+    def __init__(
+        self,
+        web: bool,
+        codec: Codec | None,
+        read_max_bytes: int,
+        stream: AsyncIterable[bytes] | None = None,
+        compression: Compression | None = None,
+    ) -> None:
+        """Initialize the protocol gRPC handler.
+
+        Args:
+            web (bool): Indicates if the connection is for a web environment.
+            codec (Codec | None): The codec to use for encoding/decoding messages. Can be None.
+            read_max_bytes (int): The maximum number of bytes to read from the stream.
+            stream (AsyncIterable[bytes] | None, optional): An asynchronous iterable stream of bytes. Defaults to None.
+            compression (Compression | None, optional): The compression method to use. Defaults to None.
+
+        """
+        super().__init__(codec, read_max_bytes, stream, compression)
+        self.web = web
+        self._web_trailers = None
+
+    async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]:
+        """Asynchronously unmarshals a given message and yields each resulting (object, end) pair.
+
+        Args:
+            message (Any): The message to be unmarshaled.
+
+        Yields:
+            tuple[Any, bool]: Each decoded object paired with a flag that is True for the end-of-stream envelope.
+ + """ + async for obj, end in super().unmarshal(message): + if end: + env = self.last + if not env: + raise ConnectError("protocol error: empty envelope") + + data = copy(env.data) + env.data = b"" + + if not (self.web and env.is_set(EnvelopeFlags.trailer)): + raise ConnectError( + f"protocol error: invalid envelope flags: {env.flags}", + ) + + trailers = Headers() + lines = data.decode("utf-8").splitlines() + for line in lines: + if line == "": + continue + + name, value = line.split(":", 1) + name = name.strip().lower() + value = value.strip() + if name in trailers: + trailers[name] += "," + value + else: + trailers[name] = value + + self._web_trailers = trailers + + yield obj, end + + @property + def web_trailers(self) -> Headers | None: + """Return the trailers received in the last envelope. + + Returns: + Headers | None: The trailers received in the last envelope, or None if no trailers were received. + + """ + return self._web_trailers From 3fb235ab3350f68cb01b6bd1efa67c986b3c412f Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Tue, 13 May 2025 20:22:03 +0900 Subject: [PATCH 3/3] all: rename functions and files --- src/connect/asgi_helpers/__init__.py | 0 src/connect/asgi_helpers/utils.py | 77 +++++++ src/connect/byte_stream.py | 56 ----- src/connect/connect.py | 3 +- src/connect/content_stream.py | 158 ++++++++++++++ src/connect/handler.py | 5 +- src/connect/middleware.py | 2 +- src/connect/protocol.py | 2 +- .../protocol_connect/connect_client.py | 6 +- .../protocol_connect/connect_handler.py | 5 +- src/connect/protocol_grpc/grpc_client.py | 4 +- src/connect/protocol_grpc/grpc_handler.py | 4 +- src/connect/response.py | 147 +++++++++++++ src/connect/{writer.py => response_writer.py} | 0 src/connect/streaming_response.py | 150 ------------- src/connect/utils.py | 206 ------------------ 16 files changed, 397 insertions(+), 428 deletions(-) create mode 100644 src/connect/asgi_helpers/__init__.py create mode 100644 src/connect/asgi_helpers/utils.py delete mode 
100644 src/connect/byte_stream.py create mode 100644 src/connect/content_stream.py rename src/connect/{writer.py => response_writer.py} (100%) delete mode 100644 src/connect/streaming_response.py diff --git a/src/connect/asgi_helpers/__init__.py b/src/connect/asgi_helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/connect/asgi_helpers/utils.py b/src/connect/asgi_helpers/utils.py new file mode 100644 index 0000000..2002c93 --- /dev/null +++ b/src/connect/asgi_helpers/utils.py @@ -0,0 +1,77 @@ +"""Utility functions for ASGI helpers, including request/response decorators and route path extraction.""" + +import typing +from collections.abc import ( + Awaitable, + Callable, +) + +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp, Receive, Scope, Send + +from connect.utils import is_async_callable, run_in_threadpool + + +def request_response(func: Callable[[Request], Awaitable[Response] | Response]) -> ASGIApp: + """Convert a request handler function into an ASGI application. + + This decorator takes a function that handles a request and returns a response, + and wraps it into an ASGI application callable. The handler function can be either + synchronous or asynchronous. + + Args: + func (Callable[[Request], Awaitable[Response] | Response]): The request handler function. + It can be a synchronous function returning a Response or an asynchronous function + returning an Awaitable of Response. + + Returns: + ASGIApp: An ASGI application callable that can be used to handle ASGI requests. 
+ + """ + + async def async_func(request: Request) -> Response: + if is_async_callable(func): + return await func(request) + else: + return typing.cast(Response, await run_in_threadpool(func, request)) + + f: Callable[[Request], Awaitable[Response]] = async_func + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope, receive, send) + response = await f(request) + await response(scope, receive, send) + + return app + + +def get_route_path(scope: Scope) -> str: + """Extract the route path from the given scope. + + Args: + scope (Scope): The scope dictionary containing the request information. + + Returns: + str: The extracted route path. If a root path is specified in the scope, + the function returns the path relative to the root path. If the path + does not start with the root path or if the path is equal to the root + path, the function returns the original path or an empty string, + respectively. + + """ + path: str = scope["path"] + root_path = scope.get("root_path", "") + if not root_path: + return path + + if not path.startswith(root_path): + return path + + if path == root_path: + return "" + + if path[len(root_path)] == "/": + return path[len(root_path) :] + + return path diff --git a/src/connect/byte_stream.py b/src/connect/byte_stream.py deleted file mode 100644 index 5033d9d..0000000 --- a/src/connect/byte_stream.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Asynchronous byte stream utilities for HTTP core response handling.""" - -from collections.abc import ( - AsyncIterable, - AsyncIterator, -) - -from connect.utils import ( - AsyncByteStream, - get_acallable_attribute, - map_httpcore_exceptions, -) - - -class HTTPCoreResponseAsyncByteStream(AsyncByteStream): - """An asynchronous byte stream for reading and writing byte chunks.""" - - aiterator: AsyncIterable[bytes] | None - _closed: bool - - def __init__( - self, - aiterator: AsyncIterable[bytes] | None = None, - ) -> None: - """Initialize the protocol connect 
instance. - - Args: - aiterator (AsyncIterable[bytes] | None): An optional asynchronous iterable of bytes. - - Returns: - None - - """ - self.aiterator = aiterator - self._closed = False - - async def __aiter__(self) -> AsyncIterator[bytes]: - """Asynchronous iterator method to read byte chunks from the stream.""" - if self.aiterator: - try: - with map_httpcore_exceptions(): - async for chunk in self.aiterator: - yield chunk - except BaseException as exc: - await self.aclose() - raise exc - - async def aclose(self) -> None: - """Asynchronously close the stream.""" - if not self._closed and self.aiterator: - aclose = get_acallable_attribute(self.aiterator, "aclose") - if not aclose: - return - - with map_httpcore_exceptions(): - await aclose() diff --git a/src/connect/connect.py b/src/connect/connect.py index d628824..81b36fc 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -10,10 +10,11 @@ from pydantic import BaseModel from connect.code import Code +from connect.content_stream import AsyncDataStream from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.utils import AsyncDataStream, aiterate, get_acallable_attribute, get_callable_attribute +from connect.utils import aiterate, get_acallable_attribute, get_callable_attribute class StreamType(Enum): diff --git a/src/connect/content_stream.py b/src/connect/content_stream.py new file mode 100644 index 0000000..7e25f5d --- /dev/null +++ b/src/connect/content_stream.py @@ -0,0 +1,158 @@ +"""Asynchronous byte stream utilities for HTTP core response handling.""" + +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, +) + +from connect.utils import ( + get_acallable_attribute, + map_httpcore_exceptions, +) + + +class AsyncByteStream(AsyncIterable[bytes]): + """An abstract base class for asynchronous byte streams. 
+ + This class defines the interface for an asynchronous byte stream, which + includes methods for iterating over the stream and closing it. + + """ + + async def __aiter__(self) -> AsyncIterator[bytes]: + """Asynchronous iterator method. + + This method should be implemented to provide asynchronous iteration + over the object. It must return an asynchronous iterator that yields + bytes. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError("The '__aiter__' method must be implemented.") # pragma: no cover + yield b"" + + async def aclose(self) -> None: + """Asynchronously close the byte stream.""" + pass + + +class BoundAsyncStream(AsyncByteStream): + """An asynchronous byte stream wrapper that binds to an existing async iterable of bytes. + + This class provides an asynchronous iterator interface for reading byte chunks from the given stream, + and ensures proper resource cleanup by closing the underlying stream when needed. + + Args: + stream (AsyncIterable[bytes]): The asynchronous iterable byte stream to wrap. + + Attributes: + stream (AsyncIterable[bytes]): The wrapped asynchronous byte stream. + _closed (bool): Indicates whether the stream has been closed. + + """ + + _stream: AsyncIterable[bytes] + _closed: bool + + def __init__(self, stream: AsyncIterable[bytes]) -> None: + """Initialize the object with an asynchronous iterable stream of bytes. + + Args: + stream (AsyncIterable[bytes]): An asynchronous iterable that yields bytes. 
+ + """ + self._stream = stream + self._closed = False + + async def __aiter__(self) -> AsyncIterator[bytes]: + """Asynchronous iterator method to read byte chunks from the stream.""" + try: + with map_httpcore_exceptions(): + async for chunk in self._stream: + yield chunk + except BaseException as exc: + await self.aclose() + raise exc + + async def aclose(self) -> None: + """Asynchronously close the stream.""" + if self._closed: + return + + self._closed = True + with map_httpcore_exceptions(): + aclose = get_acallable_attribute(self._stream, "aclose") + if aclose: + await aclose() + + +class AsyncDataStream[T]: + """An asynchronous data stream wrapper that provides iteration and cleanup functionality. + + Type Parameters: + T: The type of items yielded by the stream. + + Attributes: + _stream (AsyncIterable[T]): The underlying asynchronous iterable data stream. + aclose_func (Callable[..., Awaitable[None]] | None): Optional asynchronous cleanup function to be called on close. + + """ + + _stream: AsyncIterable[T] + _aclose_func: Callable[..., Awaitable[None]] | None + _closed: bool + + def __init__(self, stream: AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None) -> None: + """Initialize the object with an asynchronous iterable stream and an optional asynchronous close function. + + Args: + stream (AsyncIterable[T]): The asynchronous iterable stream to be wrapped. + aclose_func (Callable[..., Awaitable[None]], optional): An optional asynchronous function to be called when closing the stream. Defaults to None. + + """ + self._stream = stream + self._aclose_func = aclose_func + self._closed = False + + async def __aiter__(self) -> AsyncIterator[T]: + """Asynchronously iterates over the underlying stream, yielding each part. + + Yields: + T: The next part from the stream. + + Raises: + Propagates any exception raised during iteration after ensuring the stream is closed. 
+ + """ + try: + async for part in self._stream: + yield part + except BaseException as exc: + await self.aclose() + raise exc + + async def aclose(self) -> None: + """Asynchronously closes the underlying stream. + + If a custom asynchronous close function (`aclose_func`) is provided, it is awaited. + Otherwise, if the underlying stream has an `aclose` method, it is retrieved and awaited. + + Raises: + Any exception raised by the custom close function or the stream's `aclose` method. + + """ + if self._closed: + return + + self._closed = True + if self._aclose_func: + await self._aclose_func() + + elif hasattr(self._stream, "aclose"): + aclose = get_acallable_attribute(self._stream, "aclose") + if aclose: + await aclose() diff --git a/src/connect/handler.py b/src/connect/handler.py index 2d65d88..bb3be8a 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -6,7 +6,7 @@ from typing import Any import anyio -from starlette.responses import PlainTextResponse +from starlette.responses import PlainTextResponse, Response from connect.code import Code from connect.codec import Codec, CodecMap, CodecNameType, ProtoBinaryCodec, ProtoJSONCodec @@ -42,9 +42,8 @@ ) from connect.protocol_grpc.grpc_protocol import ProtocolGRPC from connect.request import Request -from connect.response import Response +from connect.response_writer import ServerResponseWriter from connect.utils import aiterate -from connect.writer import ServerResponseWriter type UnaryFunc[T_Request, T_Response] = Callable[[UnaryRequest[T_Request]], Awaitable[UnaryResponse[T_Response]]] type StreamFunc[T_Request, T_Response] = Callable[[StreamRequest[T_Request]], Awaitable[StreamResponse[T_Response]]] diff --git a/src/connect/middleware.py b/src/connect/middleware.py index 4a88d60..9c7f03b 100644 --- a/src/connect/middleware.py +++ b/src/connect/middleware.py @@ -6,8 +6,8 @@ from starlette.responses import Response from starlette.types import ASGIApp, Receive, Scope, Send +from 
connect.asgi_helpers.utils import get_route_path, request_response from connect.handler import Handler -from connect.utils import get_route_path, request_response HandleFunc = Callable[[Request], Awaitable[Response]] diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 2bce004..348129d 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -20,8 +20,8 @@ from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel from connect.request import Request +from connect.response_writer import ServerResponseWriter from connect.session import AsyncClientSession -from connect.writer import ServerResponseWriter PROTOCOL_CONNECT = "connect" PROTOCOL_GRPC = "grpc" diff --git a/src/connect/protocol_connect/connect_client.py b/src/connect/protocol_connect/connect_client.py index ad31d7e..e8cb072 100644 --- a/src/connect/protocol_connect/connect_client.py +++ b/src/connect/protocol_connect/connect_client.py @@ -15,7 +15,6 @@ import httpcore from yarl import URL -from connect.byte_stream import HTTPCoreResponseAsyncByteStream from connect.code import Code from connect.codec import Codec, StableCodec from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name @@ -27,6 +26,7 @@ StreamType, ensure_single, ) +from connect.content_stream import BoundAsyncStream from connect.error import ConnectError from connect.headers import Headers, include_request_headers from connect.idempotency_level import IdempotencyLevel @@ -436,7 +436,7 @@ async def send( hook(response) assert isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(response.stream) + self.unmarshaler.stream = BoundAsyncStream(response.stream) await self._validate_response(response) @@ -795,7 +795,7 @@ async def send( hook(response) assert isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) + 
self.unmarshaler.stream = BoundAsyncStream(response.stream) await self._validate_response(response) diff --git a/src/connect/protocol_connect/connect_handler.py b/src/connect/protocol_connect/connect_handler.py index be1bcb6..e60753a 100644 --- a/src/connect/protocol_connect/connect_handler.py +++ b/src/connect/protocol_connect/connect_handler.py @@ -53,12 +53,11 @@ from connect.protocol_connect.marshaler import ConnectStreamingMarshaler, ConnectUnaryMarshaler from connect.protocol_connect.unmarshaler import ConnectStreamingUnmarshaler, ConnectUnaryUnmarshaler from connect.request import Request -from connect.response import Response -from connect.streaming_response import StreamingResponse +from connect.response import Response, StreamingResponse +from connect.response_writer import ServerResponseWriter from connect.utils import ( aiterate, ) -from connect.writer import ServerResponseWriter class ConnectHandler(ProtocolHandler): diff --git a/src/connect/protocol_grpc/grpc_client.py b/src/connect/protocol_grpc/grpc_client.py index 58ace49..51e69b9 100644 --- a/src/connect/protocol_grpc/grpc_client.py +++ b/src/connect/protocol_grpc/grpc_client.py @@ -10,7 +10,6 @@ import httpcore from yarl import URL -from connect.byte_stream import HTTPCoreResponseAsyncByteStream from connect.code import Code from connect.codec import Codec from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name @@ -20,6 +19,7 @@ StreamingClientConn, StreamType, ) +from connect.content_stream import BoundAsyncStream from connect.error import ConnectError from connect.headers import Headers, include_request_headers from connect.protocol import ( @@ -387,7 +387,7 @@ async def send( hook(response) assert isinstance(response.stream, AsyncIterable) - self.unmarshaler.stream = HTTPCoreResponseAsyncByteStream(aiterator=response.stream) + self.unmarshaler.stream = BoundAsyncStream(response.stream) self.receive_trailers = functools.partial(self._receive_trailers, 
response) await self._validate_response(response) diff --git a/src/connect/protocol_grpc/grpc_handler.py b/src/connect/protocol_grpc/grpc_handler.py index 98d573d..fdd34b1 100644 --- a/src/connect/protocol_grpc/grpc_handler.py +++ b/src/connect/protocol_grpc/grpc_handler.py @@ -35,9 +35,9 @@ from connect.protocol_grpc.marshaler import GRPCMarshaler from connect.protocol_grpc.unmarshaler import GRPCUnmarshaler from connect.request import Request -from connect.streaming_response import StreamingResponse +from connect.response import StreamingResponse +from connect.response_writer import ServerResponseWriter from connect.utils import aiterate -from connect.writer import ServerResponseWriter class GRPCHandler(ProtocolHandler): diff --git a/src/connect/response.py b/src/connect/response.py index 76ffee3..978b2e3 100644 --- a/src/connect/response.py +++ b/src/connect/response.py @@ -1,3 +1,150 @@ """Response module for the connect package.""" +import typing +from functools import partial +from typing import Any + +import anyio +from starlette._utils import collapse_excgroups # type: ignore +from starlette.background import BackgroundTask +from starlette.concurrency import iterate_in_threadpool +from starlette.requests import ClientDisconnect from starlette.responses import Response as Response +from starlette.types import Receive, Scope, Send + +ContentStream = typing.Iterable[typing.Any] | typing.AsyncIterable[typing.Any] +AsyncContentStream = typing.AsyncIterable[typing.Any] + + +class StreamingResponse(Response): + """A streaming HTTP response class that supports HTTP trailers. + + This class extends the standard response to allow sending HTTP trailers + at the end of a streamed response body, if supported by the ASGI server. + + Attributes: + body_iterator (AsyncContentStream): An asynchronous iterator over the response body content. + status_code (int): HTTP status code for the response. + media_type (str | None): The media type of the response. 
+ background (BackgroundTask | None): Optional background task to run after response is sent. + headers (Mapping[str, str]): HTTP headers for the response. + _trailers (Mapping[str, str] | None): HTTP trailers to send after the response body. + + """ + + body_iterator: AsyncContentStream + + def __init__( + self, + content: ContentStream, + *, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + trailers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + """Initialize a response object with optional HTTP trailers. + + Args: + content (ContentStream): The response body content, which can be an async iterable or a regular iterable. + status_code (int, optional): HTTP status code for the response. Defaults to 200. + headers (typing.Mapping[str, str] | None, optional): HTTP headers to include in the response. Defaults to None. + trailers (typing.Mapping[str, str] | None, optional): HTTP trailers to include in the response. Defaults to None. + media_type (str | None, optional): The media type of the response. If None, uses the default media type. Defaults to None. + background (BackgroundTask | None, optional): A background task to run after the response is sent. Defaults to None. + + Notes: + - If `content` is not an async iterable, it will be wrapped to run in a thread pool. + - If trailers are provided, their names will be added to the "Trailer" header. 
+ + """ + if isinstance(content, typing.AsyncIterable): + self.body_iterator = content + else: + self.body_iterator = iterate_in_threadpool(content) + + self.status_code = status_code + self.media_type = self.media_type if media_type is None else media_type + self.background = background + self.init_headers(headers) + self._trailers = trailers + + if self._trailers: + names = ", ".join({k for k, _ in self._trailers.items()}) + if names: + self.headers.setdefault("Trailer", names) + + async def _stream_response(self, send: Send, trailers_supported: bool) -> None: + await send({ + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + "trailers": self._trailers is not None and trailers_supported, + }) + + async for chunk in self.body_iterator: + if not isinstance(chunk, bytes | memoryview): + chunk = chunk.encode(self.charset) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + if self._trailers is not None and trailers_supported: + encoded_headers = [(key.encode(), value.encode()) for key, value in self._trailers.items()] + await send({ + "type": "http.response.trailers", + "headers": encoded_headers, + "more_trailers": False, + }) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Handle the ASGI call interface for streaming HTTP responses with optional support for HTTP trailers. + + This method determines the ASGI spec version and whether HTTP response trailers are supported. + For ASGI spec version >= 2.4, it streams the response and handles client disconnects. + For earlier versions, it concurrently streams the response and listens for client disconnects, + cancelling the response stream if a disconnect is detected. + + After sending the response, if a background task is provided, it is awaited. + + Args: + scope (Scope): The ASGI connection scope. 
+ receive (Receive): Awaitable callable to receive ASGI messages. + send (Send): Awaitable callable to send ASGI messages. + + Raises: + ClientDisconnect: If the client disconnects during response streaming. + + """ + spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) + trailers_supported = "http.response.trailers" in scope.get("extensions", {}) + + if spec_version >= (2, 4): + try: + await self._stream_response(send, trailers_supported) + except OSError: + raise ClientDisconnect() from None + + else: + + async def listen_for_disconnect() -> None: + while True: + if (await receive())["type"] == "http.disconnect": + break + + with collapse_excgroups(): + async with anyio.create_task_group() as tg: + + async def run_and_cancel(func: Any) -> None: + await func() + tg.cancel_scope.cancel() + + tg.start_soon( + run_and_cancel, + partial(self._stream_response, send, trailers_supported), + ) + await run_and_cancel(listen_for_disconnect) + + if self.background is not None: + await self.background() diff --git a/src/connect/writer.py b/src/connect/response_writer.py similarity index 100% rename from src/connect/writer.py rename to src/connect/response_writer.py diff --git a/src/connect/streaming_response.py b/src/connect/streaming_response.py deleted file mode 100644 index 38b2a96..0000000 --- a/src/connect/streaming_response.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Streaming HTTP response with support for trailers.""" - -import typing -from functools import partial -from typing import Any - -import anyio -from starlette._utils import collapse_excgroups # type: ignore -from starlette.background import BackgroundTask -from starlette.concurrency import iterate_in_threadpool -from starlette.requests import ClientDisconnect -from starlette.responses import Response -from starlette.types import Receive, Scope, Send - -ContentStream = typing.Iterable[typing.Any] | typing.AsyncIterable[typing.Any] -AsyncContentStream = 
typing.AsyncIterable[typing.Any] - - -class StreamingResponse(Response): - """A streaming HTTP response class that supports HTTP trailers. - - This class extends the standard response to allow sending HTTP trailers - at the end of a streamed response body, if supported by the ASGI server. - - Attributes: - body_iterator (AsyncContentStream): An asynchronous iterator over the response body content. - status_code (int): HTTP status code for the response. - media_type (str | None): The media type of the response. - background (BackgroundTask | None): Optional background task to run after response is sent. - headers (Mapping[str, str]): HTTP headers for the response. - _trailers (Mapping[str, str] | None): HTTP trailers to send after the response body. - - """ - - body_iterator: AsyncContentStream - - def __init__( - self, - content: ContentStream, - *, - status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, - trailers: typing.Mapping[str, str] | None = None, - media_type: str | None = None, - background: BackgroundTask | None = None, - ) -> None: - """Initialize a response object with optional HTTP trailers. - - Args: - content (ContentStream): The response body content, which can be an async iterable or a regular iterable. - status_code (int, optional): HTTP status code for the response. Defaults to 200. - headers (typing.Mapping[str, str] | None, optional): HTTP headers to include in the response. Defaults to None. - trailers (typing.Mapping[str, str] | None, optional): HTTP trailers to include in the response. Defaults to None. - media_type (str | None, optional): The media type of the response. If None, uses the default media type. Defaults to None. - background (BackgroundTask | None, optional): A background task to run after the response is sent. Defaults to None. - - Notes: - - If `content` is not an async iterable, it will be wrapped to run in a thread pool. - - If trailers are provided, their names will be added to the "Trailer" header. 
- - """ - if isinstance(content, typing.AsyncIterable): - self.body_iterator = content - else: - self.body_iterator = iterate_in_threadpool(content) - - self.status_code = status_code - self.media_type = self.media_type if media_type is None else media_type - self.background = background - self.init_headers(headers) - self._trailers = trailers - - if self._trailers: - names = ", ".join({k for k, _ in self._trailers.items()}) - if names: - self.headers.setdefault("Trailer", names) - - async def _stream_response(self, send: Send, trailers_supported: bool) -> None: - await send({ - "type": "http.response.start", - "status": self.status_code, - "headers": self.raw_headers, - "trailers": self._trailers is not None and trailers_supported, - }) - - async for chunk in self.body_iterator: - if not isinstance(chunk, bytes | memoryview): - chunk = chunk.encode(self.charset) - await send({"type": "http.response.body", "body": chunk, "more_body": True}) - - await send({"type": "http.response.body", "body": b"", "more_body": False}) - - if self._trailers is not None and trailers_supported: - encoded_headers = [(key.encode(), value.encode()) for key, value in self._trailers.items()] - await send({ - "type": "http.response.trailers", - "headers": encoded_headers, - "more_trailers": False, - }) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Handle the ASGI call interface for streaming HTTP responses with optional support for HTTP trailers. - - This method determines the ASGI spec version and whether HTTP response trailers are supported. - For ASGI spec version >= 2.4, it streams the response and handles client disconnects. - For earlier versions, it concurrently streams the response and listens for client disconnects, - cancelling the response stream if a disconnect is detected. - - After sending the response, if a background task is provided, it is awaited. - - Args: - scope (Scope): The ASGI connection scope. 
- receive (Receive): Awaitable callable to receive ASGI messages. - send (Send): Awaitable callable to send ASGI messages. - - Raises: - ClientDisconnect: If the client disconnects during response streaming. - - """ - spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) - trailers_supported = "http.response.trailers" in scope.get("extensions", {}) - - if spec_version >= (2, 4): - try: - await self._stream_response(send, trailers_supported) - except OSError: - raise ClientDisconnect() from None - - else: - - async def listen_for_disconnect() -> None: - while True: - if (await receive())["type"] == "http.disconnect": - break - - with collapse_excgroups(): - async with anyio.create_task_group() as tg: - - async def run_and_cancel(func: Any) -> None: - await func() - tg.cancel_scope.cancel() - - tg.start_soon( - run_and_cancel, - partial(self._stream_response, send, trailers_supported), - ) - await run_and_cancel(listen_for_disconnect) - - if self.background is not None: - await self.background() diff --git a/src/connect/utils.py b/src/connect/utils.py index 3e71cb5..a0f86e8 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -5,16 +5,11 @@ import functools import typing from collections.abc import ( - Awaitable, - Callable, Iterator, ) import anyio.to_thread import httpcore -from starlette.requests import Request -from starlette.responses import Response -from starlette.types import ASGIApp, Receive, Scope, Send from connect.code import Code from connect.error import ConnectError @@ -117,188 +112,6 @@ def get_acallable_attribute(obj: object, attr: str) -> typing.Callable[..., typi return None -def get_route_path(scope: Scope) -> str: - """Extract the route path from the given scope. - - Args: - scope (Scope): The scope dictionary containing the request information. - - Returns: - str: The extracted route path. If a root path is specified in the scope, - the function returns the path relative to the root path. 
If the path - does not start with the root path or if the path is equal to the root - path, the function returns the original path or an empty string, - respectively. - - """ - path: str = scope["path"] - root_path = scope.get("root_path", "") - if not root_path: - return path - - if not path.startswith(root_path): - return path - - if path == root_path: - return "" - - if path[len(root_path)] == "/": - return path[len(root_path) :] - - return path - - -def request_response(func: Callable[[Request], Awaitable[Response] | Response]) -> ASGIApp: - """Convert a request handler function into an ASGI application. - - This decorator takes a function that handles a request and returns a response, - and wraps it into an ASGI application callable. The handler function can be either - synchronous or asynchronous. - - Args: - func (Callable[[Request], Awaitable[Response] | Response]): The request handler function. - It can be a synchronous function returning a Response or an asynchronous function - returning an Awaitable of Response. - - Returns: - ASGIApp: An ASGI application callable that can be used to handle ASGI requests. - - """ - - async def async_func(request: Request) -> Response: - if is_async_callable(func): - return await func(request) - else: - return typing.cast(Response, await run_in_threadpool(func, request)) - - f: Callable[[Request], Awaitable[Response]] = async_func - - async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request(scope, receive, send) - response = await f(request) - await response(scope, receive, send) - - return app - - -class AsyncByteStream(typing.AsyncIterable[bytes]): - """An abstract base class for asynchronous byte streams. - - This class defines the interface for an asynchronous byte stream, which - includes methods for iterating over the stream and closing it. - - """ - - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - """Asynchronous iterator method. 
- - This method should be implemented to provide asynchronous iteration - over the object. It must return an asynchronous iterator that yields - bytes. - - Raises: - NotImplementedError: If the method is not implemented. - - """ - raise NotImplementedError("The '__aiter__' method must be implemented.") # pragma: no cover - yield b"" - - async def aclose(self) -> None: - """Asynchronously close the byte stream.""" - pass - - -class StreamConsumedError(Exception): - """Exception raised when a stream has already been consumed.""" - - def __init__(self) -> None: - """Initialize the exception with a default message.""" - super().__init__("Stream has already been consumed.") - - -class AsyncDataStream[T]: - """AsyncDataStream is a generic class that provides an asynchronous iterable interface for streaming data. - - It ensures that the stream is consumed only once and provides a mechanism for resource cleanup. - Type Parameters: - T: The type of elements in the asynchronous stream. - - Attributes: - _stream (typing.AsyncIterable[T]): The asynchronous iterable representing the stream of data. - _is_stream_consumed (bool): A flag indicating whether the stream has already been consumed. - aclose_func (Callable[..., Awaitable[None]] | None): An optional asynchronous callable for closing resources. - - Methods: - __init__(stream: typing.AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None) -> None: - Initializes the AsyncDataStream instance with the given stream and optional close function. - __aiter__() -> typing.AsyncIterator[T]: - Asynchronously iterates over the elements of the stream. Ensures the stream is consumed only once - and handles cleanup in case of exceptions. - aclose() -> None: - Asynchronously closes resources if an asynchronous close function is provided. 
- - """ - - _stream: typing.AsyncIterable[T] - _is_stream_consumed: bool - aclose_func: Callable[..., Awaitable[None]] | None - - def __init__( - self, stream: typing.AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None - ) -> None: - """Initialize an instance of the class. - - Args: - stream (typing.AsyncIterable[T]): An asynchronous iterable representing the stream of data. - aclose_func (Callable[..., Awaitable[None]] | None, optional): - A callable function that is awaited to close the stream. Defaults to None. - - """ - self._stream = stream - self._is_stream_consumed = False - self.aclose_func = aclose_func - - async def __aiter__(self) -> typing.AsyncIterator[T]: - """Asynchronously iterates over the elements of the stream. - - This method allows the object to be used as an asynchronous iterator. - It ensures that the stream is not consumed multiple times and properly - handles cleanup in case of exceptions. - - Yields: - T: The next element in the asynchronous stream. - - Raises: - StreamConsumedError: If the stream has already been consumed. - BaseException: Propagates any exception raised during iteration - after ensuring the stream is closed. - - """ - if self._is_stream_consumed: - raise StreamConsumedError() - - self._is_stream_consumed = True - try: - async for part in self._stream: - yield part - except BaseException as exc: - await self.aclose() - raise exc - - async def aclose(self) -> None: - """Asynchronously closes resources if an asynchronous close function is provided. - - This method checks if an `aclose_func` is defined. If it is, the function - is awaited to perform any necessary cleanup or resource deallocation. - - Returns: - None - - """ - if self.aclose_func: - await self.aclose_func() - - async def aiterate[T](iterable: typing.Iterable[T]) -> typing.AsyncIterator[T]: """Turn a plain iterable into an async iterator. 
@@ -364,22 +177,3 @@ def map_httpcore_exceptions() -> Iterator[None]: raise ConnectError(str(exc), to_code) from exc raise exc - - -async def achain[T](*itrs: typing.AsyncIterable[T]) -> typing.AsyncIterator[T]: - """Asynchronously chains multiple async iterables into a single async iterator. - - Args: - *itrs (typing.AsyncIterable[T]): A variable number of async iterables to be chained. - - Yields: - T: Items from the provided async iterables, in the order they are received. - - Example: - async for item in achain(async_iterable1, async_iterable2): - print(item) - - """ - for itr in itrs: - async for item in itr: - yield item