From 99cb6bd9cc816a48614d002bf85327c023d3d8df Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Thu, 24 Apr 2025 22:08:08 +0900 Subject: [PATCH 1/4] handler: refactor handler conn --- src/connect/connect.py | 172 ++++++--------------- src/connect/handler.py | 39 +++-- src/connect/protocol.py | 38 ++--- src/connect/protocol_connect.py | 257 +++++++++++++++----------------- src/connect/protocol_grpc.py | 257 +++++++++++--------------------- 5 files changed, 286 insertions(+), 477 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index 998f8f3..4692a62 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -318,125 +318,6 @@ async def aclose(self) -> None: await aclose() -class UnaryHandlerConn(abc.ABC): - """Abstract base class for a streaming handler connection. - - This class defines the interface for handling streaming connections, including - methods for specifying the connection, handling peer communication, receiving - and sending messages, and managing request and response headers and trailers. - - """ - - @abc.abstractmethod - def parse_timeout(self) -> float | None: - """Parse the timeout value.""" - raise NotImplementedError() - - @property - @abc.abstractmethod - def spec(self) -> Spec: - """Return the specification details. - - Returns: - Spec: The specification details. - - """ - raise NotImplementedError() - - @property - @abc.abstractmethod - def peer(self) -> Peer: - """Establish a connection to a peer in the network. - - Returns: - Any: The result of the connection attempt. The exact type and structure - of the return value will depend on the implementation details. - - """ - raise NotImplementedError() - - @abc.abstractmethod - async def receive(self, message: Any) -> Any: - """Receives a message and processes it. - - Args: - message (Any): The message to be received and processed. - - Returns: - Any: The result of processing the message. 
- - """ - raise NotImplementedError() - - @property - @abc.abstractmethod - def request_headers(self) -> Headers: - """Generate and return the request headers. - - Returns: - Any: The request headers. - - """ - raise NotImplementedError() - - @abc.abstractmethod - async def send(self, message: Any) -> None: - """Send a message. - - This method should be implemented by subclasses to define how the message - should be sent. - - Args: - message (Any): The message to be sent. - - Raises: - NotImplementedError: If the method is not implemented by a subclass. - - """ - raise NotImplementedError() - - @property - @abc.abstractmethod - def response_headers(self) -> Headers: - """Retrieve the response headers. - - Returns: - Any: The response headers. - - """ - raise NotImplementedError() - - @property - @abc.abstractmethod - def response_trailers(self) -> Headers: - """Handle response trailers. - - This method is intended to be overridden in subclasses to provide - specific functionality for processing response trailers. - - Returns: - Any: The return type is not specified as this is a placeholder method. - - """ - raise NotImplementedError() - - @abc.abstractmethod - async def send_error(self, error: ConnectError) -> None: - """Send an error message. - - This method should be implemented to handle the sending of error messages - in a specific manner defined by the subclass. - - Args: - error (ConnectError): The error to be sent. - - Raises: - NotImplementedError: If the method is not implemented by the subclass. - - """ - raise NotImplementedError() - - class StreamingHandlerConn(abc.ABC): """Abstract base class for a streaming handler connection. @@ -482,7 +363,8 @@ def receive(self, message: Any) -> AsyncIterator[Any]: message (Any): The message to be received and processed. Returns: - Any: The result of processing the message. + AsyncIterator[Any]: An async iterator of processing results. + For unary operations, this will yield exactly one item. 
""" raise NotImplementedError() @@ -493,7 +375,7 @@ def request_headers(self) -> Headers: """Generate and return the request headers. Returns: - Any: The request headers. + Headers: The request headers. """ raise NotImplementedError() @@ -504,6 +386,7 @@ async def send(self, messages: AsyncIterable[Any]) -> None: Args: messages (AsyncIterable[Any]): The messages to be sent. + For unary operations, this should be an iterable with a single item. Raises: NotImplementedError: This method should be implemented by subclasses. @@ -517,7 +400,7 @@ def response_headers(self) -> Headers: """Retrieve the response headers. Returns: - Any: The response headers. + Headers: The response headers. """ raise NotImplementedError() @@ -531,7 +414,7 @@ def response_trailers(self) -> Headers: specific functionality for processing response trailers. Returns: - Any: The return type is not specified as this is a placeholder method. + Headers: The response trailers. """ raise NotImplementedError() @@ -552,6 +435,15 @@ async def send_error(self, error: ConnectError) -> None: """ raise NotImplementedError() + def is_unary(self) -> bool: + """Check if this connection is for a unary operation. + + Returns: + bool: True if this is a unary operation connection, False otherwise. + + """ + return self.spec.stream_type == StreamType.Unary + class UnaryClientConn: """Abstract base class for a streaming client connection.""" @@ -569,7 +461,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - async def receive(self, message: Any) -> Any: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and processes it.""" raise NotImplementedError() @@ -679,14 +571,14 @@ def spec(self) -> Spec: raise NotImplementedError() @abc.abstractmethod - async def receive(self, message: Any) -> Any: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and processes it. Args: message (Any): The message to be received and processed. 
Returns: - Any: The result of processing the message. + AsyncIterator[Any]: An async iterator of processing results. Raises: NotImplementedError: This method should be implemented by subclasses. @@ -695,11 +587,11 @@ async def receive(self, message: Any) -> Any: raise NotImplementedError() -async def receive_unary_request[T](conn: UnaryHandlerConn, t: type[T]) -> UnaryRequest[T]: +async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> UnaryRequest[T]: """Receives a unary request from the given connection and returns a UnaryRequest object. Args: - conn (UnaryHandlerConn): The connection from which to receive the unary request. + conn (StreamingHandlerConn): The connection from which to receive the unary request. t (type[T]): The type of the message to be received. Returns: @@ -878,6 +770,26 @@ async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: Returns: T: The received message of type T. + Raises: + ConnectError: If no message is received or multiple messages are received. 
+ """ - message = await conn.receive(t) - return message + first = None + count = 0 + + async for message in conn.receive(t): + count += 1 + if count > 1: + raise ConnectError( + f"received extra input message for {conn.spec.procedure} method", + Code.UNIMPLEMENTED, + ) + first = message + + if first is None: + raise ConnectError( + f"missing input message for {conn.spec.procedure} method", + Code.UNIMPLEMENTED, + ) + + return first diff --git a/src/connect/handler.py b/src/connect/handler.py index 4a53f95..073e39b 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -1,7 +1,6 @@ """Module provides handler configurations and implementations for unary procedures and stream types.""" import asyncio -import inspect import logging from collections.abc import Awaitable, Callable from http import HTTPMethod, HTTPStatus @@ -19,7 +18,6 @@ StreamRequest, StreamResponse, StreamType, - UnaryHandlerConn, UnaryRequest, UnaryResponse, receive_stream_request, @@ -46,6 +44,7 @@ from connect.protocol_grpc import ProtocolGPRC from connect.request import Request from connect.response import Response +from connect.utils import aiterate from connect.writer import ServerResponseWriter logging.basicConfig(level=logging.INFO) @@ -148,7 +147,7 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: return handlers -UnaryImplementationFunc = Callable[[UnaryHandlerConn, float | None], Awaitable[None]] +UnaryImplementationFunc = Callable[[StreamingHandlerConn, float | None], Awaitable[None]] StreamImplementationFunc = Callable[[StreamingHandlerConn, float | None], Awaitable[None]] @@ -291,9 +290,12 @@ async def handle(self, request: Request) -> Response: async def _handle( self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter ) -> None: - if self.is_stream(self.implementation): + # Check the stream type of the handler + if getattr(self, "stream_type", StreamType.Unary) != StreamType.Unary: + 
self._is_stream_handler = True await self.stream_handle(request, response_headers, response_trailers, writer) else: + self._is_stream_handler = False await self.unary_handle(request, response_headers, response_trailers, writer) def is_stream( @@ -308,9 +310,9 @@ def is_stream( TypeGuard[StreamImplementationFunc]: True if the implementation function is a stream implementation, False otherwise. """ - signature = inspect.signature(impl) - parameters = signature.parameters - return len(parameters) == 2 and next(iter(parameters.values())).annotation == StreamingHandlerConn + # Since we've consolidated to a single connection type, use a sentinel value in the handler config + is_stream_handler = getattr(self, "_is_stream_handler", False) + return is_stream_handler def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) -> TypeGuard[UnaryImplementationFunc]: """Determine if the given implementation function is a unary implementation. @@ -322,9 +324,9 @@ def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) -> TypeGuard[UnaryImplementationFunc]: True if the implementation function is a unary implementation, False otherwise. """ - signature = inspect.signature(impl) - parameters = signature.parameters - return len(parameters) == 2 and next(iter(parameters.values())).annotation == UnaryHandlerConn + # Since we've consolidated to a single connection type, use a sentinel value in the handler config + is_stream_handler = getattr(self, "_is_stream_handler", False) + return not is_stream_handler async def stream_handle( self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter @@ -345,7 +347,8 @@ async def stream_handle( ConnectError: If an internal error occurs during the handling of the stream. 
""" - conn = await self.protocol_handler.stream_conn(request, response_headers, response_trailers, writer) + self._is_stream_handler = True + conn = await self.protocol_handler.conn(request, response_headers, response_trailers, writer, is_streaming=True) if conn is None: return @@ -387,6 +390,7 @@ async def unary_handle( None """ + self._is_stream_handler = False conn = await self.protocol_handler.conn(request, response_headers, response_trailers, writer) if conn is None: return @@ -399,7 +403,6 @@ async def unary_handle( timeout = conn.parse_timeout() if timeout: timeout_ms = int(timeout * 1000) - with anyio.fail_after(delay=timeout): await implementation(conn, timeout_ms) else: @@ -413,7 +416,6 @@ async def unary_handle( if isinstance(e, NotImplementedError): error = ConnectError("not implemented", Code.UNIMPLEMENTED) - await conn.send_error(error) @@ -438,6 +440,7 @@ class UnaryHandler[T_Request, T_Response](Handler): protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]] allow_methods: str accept_post: str + stream_type: StreamType = StreamType.Unary def __init__( self, @@ -460,7 +463,7 @@ async def _untyped(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: UnaryHandlerConn, timeout: float | None) -> None: + async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: request = await receive_unary_request(conn, input) if timeout: request.timeout = timeout @@ -475,7 +478,7 @@ async def implementation(conn: UnaryHandlerConn, timeout: float | None) -> None: conn.response_headers.update(exclude_protocol_headers(response.headers)) conn.response_trailers.update(exclude_protocol_headers(response.trailers)) - await conn.send(response.message) + await conn.send(aiterate([response.message])) super().__init__( procedure=procedure, @@ -503,6 +506,8 @@ class ServerStreamHandler[T_Request, T_Response](Handler): """ + stream_type: StreamType = 
StreamType.ServerStream + def __init__( self, procedure: str, @@ -575,6 +580,8 @@ class ClientStreamHandler[T_Request, T_Response](Handler): """ + stream_type: StreamType = StreamType.ClientStream + def __init__( self, procedure: str, @@ -664,6 +671,8 @@ class BidiStreamHandler[T_Request, T_Response](Handler): """ + stream_type: StreamType = StreamType.BiDiStream + def __init__( self, procedure: str, diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 2d2ae83..93ed608 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -16,7 +16,6 @@ StreamingHandlerConn, StreamType, UnaryClientConn, - UnaryHandlerConn, ) from connect.error import ConnectError from connect.headers import Headers @@ -108,7 +107,7 @@ def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: raise NotImplementedError() -HanderConn = UnaryHandlerConn | StreamingHandlerConn +HanderConn = StreamingHandlerConn class ProtocolHandler(abc.ABC): @@ -157,39 +156,24 @@ def can_handle_payload(self, request: Request, content_type: str) -> bool: @abc.abstractmethod async def conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> UnaryHandlerConn | None: - """Handle a unary connection request. - - Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be sent in the response. - response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer used to send the response. - - Returns: - UnaryHandlerConn | None: The connection handler or None if not implemented. - - Raises: - NotImplementedError: If the method is not implemented. 
- - """ - raise NotImplementedError() - - @abc.abstractmethod - async def stream_conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter + self, + request: Request, + response_headers: Headers, + response_trailers: Headers, + writer: ServerResponseWriter, + is_streaming: bool = False, ) -> StreamingHandlerConn | None: - """Handle a streaming connection. + """Handle a connection request. Args: request (Request): The incoming request object. response_headers (Headers): The headers to be sent in the response. response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer object to send the response. + writer (ServerResponseWriter): The writer used to send the response. + is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. Returns: - StreamingHandlerConn | None: The streaming handler connection object or None if not implemented. + StreamingHandlerConn | None: The connection handler or None if not implemented. Raises: NotImplementedError: If the method is not implemented. 
diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index c7c1c04..15742d9 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -32,7 +32,6 @@ StreamingHandlerConn, StreamType, UnaryClientConn, - UnaryHandlerConn, ) from connect.envelope import Envelope, EnvelopeFlags from connect.error import DEFAULT_ANY_RESOLVER_PREFIX, ConnectError, ErrorDetail @@ -175,116 +174,44 @@ def can_handle_payload(self, request: Request, content_type: str) -> bool: return content_type in self.accept - async def stream_conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter + async def conn( + self, + request: Request, + response_headers: Headers, + response_trailers: Headers, + writer: ServerResponseWriter, + is_streaming: bool = False, ) -> StreamingHandlerConn | None: - """Establish a streaming connection for the given request and response. + """Handle a connection request. Args: request (Request): The incoming request object. - response_headers (Headers): Headers to be sent in the response. - response_trailers (Headers): Trailers to be sent in the response. - writer (ServerResponseWriter): The writer to send the response. + response_headers (Headers): The headers to be sent in the response. + response_trailers (Headers): The trailers to be sent in the response. + writer (ServerResponseWriter): The writer used to send the response. + is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. Returns: - StreamingHandlerConn | None: The streaming connection handler if no error occurs, otherwise None. + StreamingHandlerConn | None: The connection handler or None if not implemented. Raises: ConnectError: If there is an error in negotiating compression, protocol version, or message encoding. 
- """ - content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) - accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) - - request_compression, response_compression, error = negotiate_compression( - self.params.compressions, content_encoding, accept_encoding - ) - - if error is None: - required = self.params.require_connect_protocol_header and self.params.spec.stream_type == StreamType.Unary - error = connect_check_protocol_version(request, required) - - content_type = request.headers.get(HEADER_CONTENT_TYPE, "") - codec_name = connect_codec_from_content_type(self.params.spec.stream_type, content_type) - - codec = self.params.codecs.get(codec_name) - if error is None and codec is None: - error = ConnectError( - f"invalid message encoding: {codec_name}", - Code.INVALID_ARGUMENT, - ) - - response_headers[HEADER_CONTENT_TYPE] = content_type - - if response_compression and response_compression.name != COMPRESSION_IDENTITY: - response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name - - response_headers[CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION] = ( - f"{', '.join(c.name for c in self.params.compressions)}" - ) - - peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, - protocol=PROTOCOL_CONNECT, - query=request.query_params, - ) - - stream_conn = ConnectStreamingHandlerConn( - writer=writer, - request=request, - peer=peer, - spec=self.params.spec, - marshaler=ConnectStreamingMarshaler( - codec=codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=response_compression, - ), - unmarshaler=ConnectStreamingUnmarshaler( - stream=request.stream(), - codec=codec, - compression=request_compression, - read_max_bytes=self.params.read_max_bytes, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - 
response_trailers=response_trailers, - ) - - if error: - await stream_conn.send_error(error) - return None - - return stream_conn - - async def conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> UnaryHandlerConn | None: - """Handle an incoming connection request and returns a UnaryHandlerConn object if successful. - - Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be included in the response. - response_trailers (Headers): The trailers to be included in the response. - writer (ServerResponseWriter): The writer object to send the response. - - Returns: - UnaryHandlerConn | None: A UnaryHandlerConn object if the connection is successfully established, - otherwise None if an error occurs. - - Raises: - ConnectError: If there are issues with the request parameters, encoding, or protocol version. - """ query_params = request.query_params - if HTTPMethod(request.method) == HTTPMethod.GET: - content_encoding = query_params.get(CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, None) + # Get content encoding and accept encoding appropriately based on stream type + if is_streaming: + content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) else: - content_encoding = request.headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) - - accept_encoding = request.headers.get(CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, None) + if HTTPMethod(request.method) == HTTPMethod.GET: + content_encoding = query_params.get(CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, None) + else: + content_encoding = request.headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, None) + # Negotiate compression request_compression, response_compression, error = negotiate_compression( self.params.compressions, 
content_encoding, accept_encoding ) @@ -293,7 +220,8 @@ async def conn( required = self.params.require_connect_protocol_header and self.params.spec.stream_type == StreamType.Unary error = connect_check_protocol_version(request, required) - if HTTPMethod(request.method) == HTTPMethod.GET: + # Process GET parameters for unary requests + if not is_streaming and HTTPMethod(request.method) == HTTPMethod.GET: encoding = query_params.get(CONNECT_UNARY_ENCODING_QUERY_PARAMETER, "") message = query_params.get(CONNECT_UNARY_MESSAGE_QUERY_PARAMETER, "") if error is None and encoding == "": @@ -325,6 +253,7 @@ async def conn( content_type = request.headers.get(HEADER_CONTENT_TYPE, "") codec_name = connect_codec_from_content_type(self.params.spec.stream_type, content_type) + # Get codec codec = self.params.codecs.get(codec_name) if error is None and codec is None: error = ConnectError( @@ -332,40 +261,76 @@ async def conn( Code.INVALID_ARGUMENT, ) + # Set response headers response_headers[HEADER_CONTENT_TYPE] = content_type - response_headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = ( - f"{', '.join(c.name for c in self.params.compressions)}" - ) + if is_streaming: + if response_compression and response_compression.name != COMPRESSION_IDENTITY: + response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name + response_headers[CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION] = ( + f"{', '.join(c.name for c in self.params.compressions)}" + ) + else: + response_headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = ( + f"{', '.join(c.name for c in self.params.compressions)}" + ) + + # Create peer peer = Peer( address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, protocol=PROTOCOL_CONNECT, query=request.query_params, ) - conn = ConnectUnaryHandlerConn( - writer=writer, - request=request, - peer=peer, - spec=self.params.spec, - marshaler=ConnectUnaryMarshaler( - codec=codec, - 
compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=response_compression, - headers=response_headers, - ), - unmarshaler=ConnectUnaryUnmarshaler( - stream=request_stream, - codec=codec, - compression=request_compression, - read_max_bytes=self.params.read_max_bytes, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) + # Create appropriate handler connection based on stream type + conn: StreamingHandlerConn | None = None + if is_streaming: + conn = ConnectStreamingHandlerConn( + writer=writer, + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectStreamingMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=response_compression, + ), + unmarshaler=ConnectStreamingUnmarshaler( + stream=request.stream(), + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) + else: + conn = ConnectUnaryHandlerConn( + writer=writer, + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectUnaryMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=response_compression, + headers=response_headers, + ), + unmarshaler=ConnectUnaryUnmarshaler( + stream=request_stream, + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) + if error: await conn.send_error(error) return None @@ -620,7 +585,7 @@ def marshal(self, message: Any) -> bytes: return data -class 
ConnectUnaryHandlerConn(UnaryHandlerConn): +class ConnectUnaryHandlerConn(StreamingHandlerConn): """ConnectUnaryHandlerConn is a handler connection class for unary RPCs in the Connect protocol. Attributes: @@ -709,18 +674,22 @@ def peer(self) -> Peer: """ return self._peer - async def receive(self, message: Any) -> Any: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message, unmarshals it, and returns the resulting object. Args: message (Any): The message to be unmarshaled. Returns: - Any: The unmarshaled object. + AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ - obj = await self.unmarshaler.unmarshal(message) - return obj + + async def _receive() -> AsyncIterator[Any]: + obj = await self.unmarshaler.unmarshal(message) + yield obj + + return _receive() @property def request_headers(self) -> Headers: @@ -732,18 +701,32 @@ def request_headers(self) -> Headers: """ return self._request_headers - async def send(self, message: Any) -> None: - """Send a message by marshaling it into bytes. + async def send(self, messages: AsyncIterable[Any]) -> None: + """Send message(s) by marshaling them into bytes. Args: - message (Any): The message to be sent. + messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, + this should be an iterable with a single item. Returns: - bytes: The marshaled message in bytes. 
+ None """ self.merge_response_trailers() + message = None + count = 0 + + async for msg in messages: + count += 1 + if count > 1: + raise ConnectError("unary handler should only send one message", Code.INTERNAL) + + message = msg + + if message is None: + raise ConnectError("unary handler must send one message", Code.INTERNAL) + data = self.marshaler.marshal(message) response = Response(content=data, headers=self.response_headers, status_code=HTTPStatus.OK) await self.writer.write(response) @@ -2053,18 +2036,22 @@ def peer(self) -> Peer: """ return self._peer - async def receive(self, message: Any) -> Any: + def receive(self, message: Any) -> AsyncIterator[Any]: """Asynchronously receives a message, unmarshals it, and returns the resulting object. Args: message (Any): The message to be unmarshaled. Returns: - None: This method does not return a value. The unmarshaled object is returned implicitly. + AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ - obj = await self.unmarshaler.unmarshal(message) - return obj + + async def _receive() -> AsyncIterator[Any]: + obj = await self.unmarshaler.unmarshal(message) + yield obj + + return _receive() @property def request_headers(self) -> Headers: diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 1d5fd3e..9a76dc7 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -10,7 +10,7 @@ from connect.code import Code from connect.codec import Codec, CodecNameType from connect.compression import COMPRESSION_IDENTITY, Compression -from connect.connect import Address, Peer, Spec, StreamingHandlerConn, UnaryHandlerConn +from connect.connect import Address, Peer, Spec, StreamingHandlerConn, StreamType from connect.envelope import EnvelopeReader, EnvelopeWriter from connect.error import ConnectError from connect.headers import Headers @@ -78,12 +78,16 @@ def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler: return GRPCHandler(params, self.web, 
content_types) def client(self, params: ProtocolClientParams) -> ProtocolClient: - """Implement client functionality. + """Create and return a GRPCClient instance. + + Args: + params (ProtocolClientParams): The parameters required to initialize the client. + + Returns: + ProtocolClient: An instance of GRPCClient. - This method currently does nothing and is intended to be implemented - in the future with the necessary client-side logic. """ - raise NotImplementedError() + raise NotImplementedError("GRPC client is not implemented yet.") class GRPCHandler(ProtocolHandler): @@ -107,60 +111,26 @@ def can_handle_payload(self, _: Request, content_type: str) -> bool: return content_type in self.accept async def conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> UnaryHandlerConn | None: - content_encoding = request.headers.get(GRPC_HEADER_COMPRESSION) - accept_encoding = request.headers.get(GRPC_HEADER_ACCEPT_COMPRESSION) - - request_compression, response_compression, error = negotiate_compression( - self.params.compressions, content_encoding, accept_encoding - ) - - response_headers[HEADER_CONTENT_TYPE] = request.headers.get(HEADER_CONTENT_TYPE, "") - response_headers[GRPC_HEADER_ACCEPT_COMPRESSION] = f"{', '.join(c.name for c in self.params.compressions)}" - if response_compression and response_compression.name != COMPRESSION_IDENTITY: - response_headers[GRPC_HEADER_COMPRESSION] = response_compression.name - - codec_name = grpc_codec_from_content_type(self.web, request.headers.get(HEADER_CONTENT_TYPE, "")) - codec = self.params.codecs.get(codec_name) - protocol_name = PROTOCOL_GRPC if not self.web else PROTOCOL_GRPC + "-web" - - peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, - protocol=protocol_name, - query=request.query_params, - ) + self, + request: Request, + response_headers: Headers, + response_trailers: Headers, + 
writer: ServerResponseWriter, + is_streaming: bool = False, + ) -> StreamingHandlerConn | None: + """Handle a connection request. - conn = GRPCHandlerConn( - writer=writer, - spec=self.params.spec, - peer=peer, - marshaler=GRPCMarshaler( - self.web, - codec, - response_compression, - self.params.compress_min_bytes, - self.params.send_max_bytes, - ), - unmarshaler=GRPCUnmarshaler( - codec, - self.params.read_max_bytes, - request.stream(), - request_compression, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) - if error: - await conn.send_error(error) - return None + Args: + request (Request): The incoming request object. + response_headers (Headers): The headers to be sent in the response. + response_trailers (Headers): The trailers to be sent in the response. + writer (ServerResponseWriter): The writer used to send the response. + is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. - return conn + Returns: + StreamingHandlerConn | None: The connection handler or None if not implemented. 
- async def stream_conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> StreamingHandlerConn | None: + """ content_encoding = request.headers.get(GRPC_HEADER_COMPRESSION) accept_encoding = request.headers.get(GRPC_HEADER_ACCEPT_COMPRESSION) @@ -183,7 +153,8 @@ async def stream_conn( query=request.query_params, ) - conn = GRPCStreamingHandlerConn( + # Create a single unified handler class with streaming flag + conn = GRPCHandlerConn( writer=writer, spec=self.params.spec, peer=peer, @@ -203,7 +174,9 @@ async def stream_conn( request_headers=Headers(request.headers, encoding="latin-1"), response_headers=response_headers, response_trailers=response_trailers, + is_streaming=is_streaming, ) + if error: await conn.send_error(error) return None @@ -245,7 +218,7 @@ async def unmarshal(self, message: Any) -> AsyncIterator[Any]: yield obj -class GRPCHandlerConn(UnaryHandlerConn): +class GRPCHandlerConn(StreamingHandlerConn): _spec: Spec _peer: Peer writer: ServerResponseWriter @@ -254,6 +227,7 @@ class GRPCHandlerConn(UnaryHandlerConn): _request_headers: Headers _response_headers: Headers _response_trailers: Headers + _is_streaming: bool def __init__( self, @@ -265,6 +239,7 @@ def __init__( request_headers: Headers, response_headers: Headers, response_trailers: Headers | None = None, + is_streaming: bool = False, ) -> None: self.writer = writer self._spec = spec @@ -274,6 +249,7 @@ def __init__( self._request_headers = request_headers self._response_headers = response_headers self._response_trailers = response_trailers if response_trailers is not None else Headers() + self._is_streaming = is_streaming def parse_timeout(self) -> float | None: timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) @@ -304,131 +280,72 @@ def spec(self) -> Spec: def peer(self) -> Peer: return self._peer - async def receive(self, message: Any) -> Any: - first = None - async for obj in self.unmarshaler.unmarshal(message): - # 
TODO(tsubakiky): validation - if first is None: - first = obj - else: - raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) + def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message and processes it. - if first is None: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) + Args: + message (Any): The message to be received and processed. - return first + Returns: + AsyncIterator[Any]: An async iterator yielding message(s). For non-streaming operations, + this will yield exactly one item. + + """ + # Different behavior based on streaming mode + if not self._is_streaming and self.spec.stream_type == StreamType.Unary: + + async def _receive_unary() -> AsyncIterator[Any]: + count = 0 + async for obj in self.unmarshaler.unmarshal(message): + count += 1 + if count > 1: + raise ConnectError( + "protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED + ) + yield obj + + if count == 0: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) + + return _receive_unary() + else: + # Streaming mode - simply pass through all messages + async def _receive_streaming() -> AsyncIterator[Any]: + async for obj in self.unmarshaler.unmarshal(message): + # TODO(tsubakiky): validation + yield obj + + return _receive_streaming() @property def request_headers(self) -> Headers: return self._request_headers - async def send(self, message: Any) -> None: - async def iterator() -> AsyncIterator[bytes]: - error: ConnectError | None = None - try: - async for msg in self.marshaler.marshal(aiterate([message])): - yield msg - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - grpc_error_to_trailer(self.response_trailers, error) - - await self.writer.write( - StreamingResponseWithTrailers( - content=iterator(), - headers=self.response_headers, - 
trailers=self.response_trailers, - status_code=200, - ) - ) - - @property - def response_headers(self) -> Headers: - return self._response_headers - - @property - def response_trailers(self) -> Headers: - return self._response_trailers - - async def send_error(self, error: ConnectError) -> None: - grpc_error_to_trailer(self.response_trailers, error) - - await self.writer.write( - StreamingResponseWithTrailers( - content=aiterate([b""]), headers=self.response_headers, trailers=self.response_trailers, status_code=200 - ) - ) - - -class GRPCStreamingHandlerConn(StreamingHandlerConn): - _spec: Spec - _peer: Peer - writer: ServerResponseWriter - marshaler: GRPCMarshaler - unmarshaler: GRPCUnmarshaler - _request_headers: Headers - _response_headers: Headers - _response_trailers: Headers - - def __init__( - self, - writer: ServerResponseWriter, - spec: Spec, - peer: Peer, - marshaler: GRPCMarshaler, - unmarshaler: GRPCUnmarshaler, - request_headers: Headers, - response_headers: Headers, - response_trailers: Headers | None = None, - ) -> None: - self.writer = writer - self._spec = spec - self._peer = peer - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self._request_headers = request_headers - self._response_headers = response_headers - self._response_trailers = response_trailers if response_trailers is not None else Headers() - - def parse_timeout(self) -> float | None: - timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) - if not timeout: - return None - - m = _RE.match(timeout) - if m is None: - raise ConnectError(f"protocol error: invalid grpc timeout value: {timeout}") - - num_str, unit = m.groups() - num = int(num_str) - - if num > 99_999_999: - raise ConnectError(f"protocol error: timeout {timeout!r} is too long") - - if unit == "H" and num > _MAX_HOURS: - return None + async def send(self, messages: AsyncIterable[Any]) -> None: + """Send message(s) by marshaling them into bytes. 
- seconds = num * _UNIT_TO_SECONDS[unit] - return seconds + Args: + messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, + this should be an iterable with a single item. - @property - def spec(self) -> Spec: - return self._spec + Returns: + None - @property - def peer(self) -> Peer: - return self._peer + """ + # Validation for unary streams - ensure exactly one message + if not self._is_streaming and self.spec.stream_type == StreamType.Unary: + message_list = [] + async for msg in messages: + message_list.append(msg) - async def receive(self, message: Any) -> AsyncIterator[Any]: - async for obj in self.unmarshaler.unmarshal(message): - # TODO(tsubakiky): validation - yield obj + if len(message_list) != 1: + raise ConnectError( + f"unary handler expected to send exactly one message, got {len(message_list)}", Code.INTERNAL + ) - @property - def request_headers(self) -> Headers: - return self._request_headers + messages = aiterate(message_list) - async def send(self, messages: AsyncIterable[Any]) -> None: + # Common sending logic for both streaming and unary async def iterator() -> AsyncIterator[bytes]: error: ConnectError | None = None try: From 5e095ac4551862596438c09a6f9164a7b85a7fa4 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Sat, 26 Apr 2025 11:47:51 +0900 Subject: [PATCH 2/4] protocol_grpc: fix --- src/connect/handler.py | 227 ++++++++++++++++++-------------- src/connect/protocol_connect.py | 63 +++++---- src/connect/protocol_grpc.py | 92 ++++++++----- 3 files changed, 227 insertions(+), 155 deletions(-) diff --git a/src/connect/handler.py b/src/connect/handler.py index 073e39b..6a3c221 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -4,7 +4,7 @@ import logging from collections.abc import Awaitable, Callable from http import HTTPMethod, HTTPStatus -from typing import Any, TypeGuard +from typing import Any import anyio from starlette.responses import PlainTextResponse @@ -147,16 +147,11 @@ def 
create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: return handlers -UnaryImplementationFunc = Callable[[StreamingHandlerConn, float | None], Awaitable[None]] -StreamImplementationFunc = Callable[[StreamingHandlerConn, float | None], Awaitable[None]] - - class Handler: """A class to handle incoming HTTP requests and generate appropriate HTTP responses. Attributes: procedure (str): The procedure name. - implementation (UnaryImplementationFunc | StreamImplementationFunc): The implementation function for handling requests. protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A dictionary mapping HTTP methods to protocol handlers. allow_methods (str): Allowed HTTP methods. accept_post (str): Accepted content types for POST requests. @@ -165,7 +160,6 @@ class Handler: """ procedure: str - implementation: UnaryImplementationFunc | StreamImplementationFunc protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]] allow_methods: str accept_post: str @@ -174,7 +168,6 @@ class Handler: def __init__( self, procedure: str, - implementation: UnaryImplementationFunc | StreamImplementationFunc, protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]], allow_methods: str, accept_post: str, @@ -183,18 +176,31 @@ def __init__( Args: procedure (str): The name of the procedure. - implementation (UnaryImplementationFunc | StreamImplementationFunc): The function implementing the procedure. protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A dictionary mapping HTTP methods to protocol handlers. allow_methods (str): A string specifying allowed HTTP methods. accept_post (str): A string specifying if POST method is accepted. 
""" self.procedure = procedure - self.implementation = implementation self.protocol_handlers = protocol_handlers self.allow_methods = allow_methods self.accept_post = accept_post + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle the implementation of the request processing. + + This method should be overridden by subclasses. + + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + Raises: + NotImplementedError: If not implemented by a subclass + + """ + raise NotImplementedError("Implementation must be provided by subclass") + async def handle(self, request: Request) -> Response: """Handle an incoming HTTP request and return an HTTP response. @@ -254,6 +260,7 @@ async def handle(self, request: Request) -> Response: writer = ServerResponseWriter() + # Create tasks for handling the request and receiving responses main_task = asyncio.create_task(self._handle(request, response_headers, response_trailers, writer)) writer_task = asyncio.create_task(writer.receive()) @@ -298,33 +305,25 @@ async def _handle( self._is_stream_handler = False await self.unary_handle(request, response_headers, response_trailers, writer) - def is_stream( - self, impl: UnaryImplementationFunc | StreamImplementationFunc - ) -> TypeGuard[StreamImplementationFunc]: - """Determine if the given implementation function is a stream implementation. - - Args: - impl (UnaryImplementationFunc | StreamImplementationFunc): The implementation function to check. + def is_stream(self) -> bool: + """Determine if this handler is a stream handler. Returns: - TypeGuard[StreamImplementationFunc]: True if the implementation function is a stream implementation, False otherwise. + bool: True if this is a stream handler, False otherwise. 
""" - # Since we've consolidated to a single connection type, use a sentinel value in the handler config + # Since we've consolidated to a single connection type, use a sentinel value is_stream_handler = getattr(self, "_is_stream_handler", False) return is_stream_handler - def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) -> TypeGuard[UnaryImplementationFunc]: - """Determine if the given implementation function is a unary implementation. - - Args: - impl (UnaryImplementationFunc | StreamImplementationFunc): The implementation function to check. + def is_unary(self) -> bool: + """Determine if this handler is a unary handler. Returns: - TypeGuard[UnaryImplementationFunc]: True if the implementation function is a unary implementation, False otherwise. + bool: True if this is a unary handler, False otherwise. """ - # Since we've consolidated to a single connection type, use a sentinel value in the handler config + # Since we've consolidated to a single connection type, use a sentinel value is_stream_handler = getattr(self, "_is_stream_handler", False) return not is_stream_handler @@ -343,7 +342,7 @@ async def stream_handle( None Raises: - ValueError: If the function type for the stream handler is invalid. + ValueError: If the implementation method is invalid. ConnectError: If an internal error occurs during the handling of the stream. 
""" @@ -352,23 +351,16 @@ async def stream_handle( if conn is None: return - implementation = self.implementation - if not self.is_stream(implementation): - raise ValueError(f"Invalid function type for stream handler: {implementation}") - try: timeout = conn.parse_timeout() if timeout: timeout_ms = int(timeout * 1000) - with anyio.fail_after(delay=timeout): - await implementation(conn, timeout_ms) + await self.implementation(conn, timeout_ms) else: - await implementation(conn, None) - + await self.implementation(conn, None) except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - await conn.send_error(error) async def unary_handle( @@ -383,7 +375,7 @@ async def unary_handle( writer (ServerResponseWriter): The writer to send the response. Raises: - ValueError: If the function type is invalid for unary handler. + ValueError: If the implementation method is invalid. ConnectError: If there is an error parsing the timeout or an internal error occurs. 
Returns: @@ -395,19 +387,14 @@ async def unary_handle( if conn is None: return - implementation = self.implementation - if not self.is_unary(implementation): - raise ValueError(f"Invalid function type for unary handler: {implementation}") - try: timeout = conn.parse_timeout() if timeout: timeout_ms = int(timeout * 1000) with anyio.fail_after(delay=timeout): - await implementation(conn, timeout_ms) + await self.implementation(conn, timeout_ms) else: - await implementation(conn, None) - + await self.implementation(conn, None) except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) @@ -463,31 +450,44 @@ async def _untyped(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: - request = await receive_unary_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - if not isinstance(response.message, output): - raise ConnectError( - f"expected response of type: {output.__name__}", - Code.INTERNAL, - ) - - conn.response_headers.update(exclude_protocol_headers(response.headers)) - conn.response_trailers.update(exclude_protocol_headers(response.trailers)) - await conn.send(aiterate([response.message])) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a unary request and send a unary response. 
+ + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + Raises: + ConnectError: If the response message type is incorrect + + """ + request = await receive_unary_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + if not isinstance(response.message, self.output): + raise ConnectError( + f"expected response of type: {self.output.__name__}", + Code.INTERNAL, + ) + + conn.response_headers.update(exclude_protocol_headers(response.headers)) + conn.response_trailers.update(exclude_protocol_headers(response.trailers)) + await conn.send(aiterate([response.message])) + class ServerStreamHandler[T_Request, T_Response](Handler): """A handler for server-side streaming RPCs. @@ -532,31 +532,40 @@ def __init__( async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: response = await stream(request) - return response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: - request = await receive_stream_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - conn.response_headers.update(response.headers) - conn.response_trailers.update(response.trailers) - - await conn.send(response.messages) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a server stream request and response. 
+ + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + """ + request = await receive_stream_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + conn.response_headers.update(response.headers) + conn.response_trailers.update(response.trailers) + + await conn.send(response.messages) + class ClientStreamHandler[T_Request, T_Response](Handler): """A handler for client-side streaming RPCs. @@ -609,31 +618,40 @@ def __init__( async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: response = await stream(request) - return response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: - request = await receive_stream_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - conn.response_headers.update(response.headers) - conn.response_trailers.update(response.trailers) - - await conn.send(response.messages) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a client stream request and response. 
+ + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + """ + request = await receive_stream_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + conn.response_headers.update(response.headers) + conn.response_trailers.update(response.trailers) + + await conn.send(response.messages) + class BidiStreamHandler[T_Request, T_Response](Handler): """A handler for bidirectional streaming RPCs in a Connect-based framework. @@ -705,27 +723,36 @@ def __init__( async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: response = await stream(request) - return response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: - request = await receive_stream_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - conn.response_headers.update(response.headers) - conn.response_trailers.update(response.trailers) - - await conn.send(response.messages) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a bidirectional stream request and response. 
+ + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + """ + request = await receive_stream_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + conn.response_headers.update(response.headers) + conn.response_trailers.update(response.trailers) + + await conn.send(response.messages) diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 15742d9..8dbf65e 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -674,8 +674,8 @@ def peer(self) -> Peer: """ return self._peer - def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message, unmarshals it, and returns the resulting object. + async def receive_message(self, message: Any) -> AsyncIterator[Any]: + """Receives and unmarshals a message into an object. Args: message (Any): The message to be unmarshaled. @@ -684,12 +684,20 @@ def receive(self, message: Any) -> AsyncIterator[Any]: AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ + obj = await self.unmarshaler.unmarshal(message) + yield obj - async def _receive() -> AsyncIterator[Any]: - obj = await self.unmarshaler.unmarshal(message) - yield obj + def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message, unmarshals it, and returns the resulting object. - return _receive() + Args: + message (Any): The message to be unmarshaled. + + Returns: + AsyncIterator[Any]: An async iterator yielding the unmarshaled object. + + """ + return self.receive_message(message) @property def request_headers(self) -> Headers: @@ -1549,6 +1557,32 @@ def request_headers(self) -> Headers: """ return self._request_headers + async def create_message_iterator(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + """Create an async iterator that marshals messages with error handling. 
+ + Args: + messages (AsyncIterable[Any]): Messages to marshal + + Returns: + AsyncIterator[bytes]: Marshaled bytes with end stream message + + Yields: + bytes: Each marshaled message followed by an end stream message + + """ + error: ConnectError | None = None + try: + async for message in self.marshaler.marshal(messages): + yield message + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + finally: + json_obj = end_stream_to_json(error, self.response_trailers) + json_str = json.dumps(json_obj) + + body = self.marshaler.marshal_end_stream(json_str.encode()) + yield body + async def send(self, messages: AsyncIterable[Any]) -> None: """Send a stream of messages asynchronously. @@ -1566,24 +1600,9 @@ async def send(self, messages: AsyncIterable[Any]) -> None: ConnectError: If an error occurs during the marshaling process. """ - - async def iterator() -> AsyncIterator[bytes]: - error: ConnectError | None = None - try: - async for message in self.marshaler.marshal(messages): - yield message - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - json_obj = end_stream_to_json(error, self.response_trailers) - json_str = json.dumps(json_obj) - - body = self.marshaler.marshal_end_stream(json_str.encode()) - yield body - await self.writer.write( StreamingResponse( - content=iterator(), + content=self.create_message_iterator(messages), headers=self.response_headers, status_code=200, ) diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 9a76dc7..6068834 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -280,6 +280,42 @@ def spec(self) -> Spec: def peer(self) -> Peer: return self._peer + async def receive_unary(self, message: Any) -> AsyncIterator[Any]: + """Receive a single message in unary mode. + + Args: + message (Any): The message to be received and processed. 
+ + Returns: + AsyncIterator[Any]: An async iterator yielding exactly one item. + + Raises: + ConnectError: If no message is received or multiple messages are received. + + """ + count = 0 + async for obj in self.unmarshaler.unmarshal(message): + count += 1 + if count > 1: + raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) + yield obj + + if count == 0: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) + + async def receive_streaming(self, message: Any) -> AsyncIterator[Any]: + """Receive messages in streaming mode. + + Args: + message (Any): The message to be received and processed. + + Returns: + AsyncIterator[Any]: An async iterator yielding all messages. + + """ + async for obj in self.unmarshaler.unmarshal(message): + yield obj + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and processes it. @@ -293,29 +329,10 @@ def receive(self, message: Any) -> AsyncIterator[Any]: """ # Different behavior based on streaming mode if not self._is_streaming and self.spec.stream_type == StreamType.Unary: - - async def _receive_unary() -> AsyncIterator[Any]: - count = 0 - async for obj in self.unmarshaler.unmarshal(message): - count += 1 - if count > 1: - raise ConnectError( - "protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED - ) - yield obj - - if count == 0: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) - - return _receive_unary() + return self.receive_unary(message) else: # Streaming mode - simply pass through all messages - async def _receive_streaming() -> AsyncIterator[Any]: - async for obj in self.unmarshaler.unmarshal(message): - # TODO(tsubakiky): validation - yield obj - - return _receive_streaming() + return self.receive_streaming(message) @property def request_headers(self) -> Headers: @@ -345,20 +362,10 @@ async def send(self, messages: AsyncIterable[Any]) -> 
None: messages = aiterate(message_list) - # Common sending logic for both streaming and unary - async def iterator() -> AsyncIterator[bytes]: - error: ConnectError | None = None - try: - async for msg in self.marshaler.marshal(messages): - yield msg - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - grpc_error_to_trailer(self.response_trailers, error) - + # Use the common marshaling method await self.writer.write( StreamingResponseWithTrailers( - content=iterator(), + content=self.marshal_with_error_handling(messages), headers=self.response_headers, trailers=self.response_trailers, status_code=200, @@ -373,6 +380,25 @@ def response_headers(self) -> Headers: def response_trailers(self) -> Headers: return self._response_trailers + async def marshal_with_error_handling(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + """Marshal messages to bytes with error handling. + + Args: + messages (AsyncIterable[Any]): The messages to marshal + + Returns: + AsyncIterator[bytes]: An async iterator of marshaled bytes + + """ + error: ConnectError | None = None + try: + async for msg in self.marshaler.marshal(messages): + yield msg + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + finally: + grpc_error_to_trailer(self.response_trailers, error) + async def send_error(self, error: ConnectError) -> None: grpc_error_to_trailer(self.response_trailers, error) From 5419aaae5c446878131e98cf3ecc7fff4d3e1f10 Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Wed, 30 Apr 2025 23:55:48 +0900 Subject: [PATCH 3/4] connect: async content stream with stream validation --- conformance/run-testcase.txt | 2 +- src/connect/connect.py | 86 ++++++++++++++++++--------------- src/connect/protocol_connect.py | 41 +++++++++++++--- src/connect/protocol_grpc.py | 47 ++---------------- 4 files changed, 87 insertions(+), 89 deletions(-) diff --git 
a/conformance/run-testcase.txt b/conformance/run-testcase.txt index b42c6a1..26b35a5 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -gRPC Unexpected Requests/HTTPVersion:2/TLS:false/unary/no-request +gRPC Unexpected Requests/HTTPVersion:2/TLS:true/unary/multiple-requests diff --git a/src/connect/connect.py b/src/connect/connect.py index 4692a62..4a4bfea 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -318,6 +318,48 @@ async def aclose(self) -> None: await aclose() +class AsyncContentStream[T](AsyncIterable[T]): + """AsyncContentStream is a wrapper around AsyncIterable to provide a consistent interface.""" + + def __init__(self, iterable: AsyncIterable[T], stream_type: StreamType) -> None: + self._iterable = iterable + self.stream_type = stream_type + + def _needs_single_message_validation(self) -> bool: + return self.stream_type == StreamType.Unary or self.stream_type == StreamType.ServerStream + + async def _ensure_single(self, iterable: AsyncIterable[T]) -> AsyncIterator[T]: + count = 0 + async for item in iterable: + if count > 0: + raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) + + yield item + count += 1 + + if count == 0: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) + + async def __aiter__(self) -> AsyncIterator[T]: + """Asynchronously iterate over the content stream.""" + if self._needs_single_message_validation(): + async for item in self._ensure_single(self._iterable): + yield item + else: + async for item in self._iterable: + yield item + + async def ensure_single(self) -> T: + message = None + async for item in self._ensure_single(self._iterable): + message = item + + if message is None: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) + + return message + + class StreamingHandlerConn(abc.ABC): """Abstract base class for a streaming handler 
connection. @@ -356,7 +398,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - def receive(self, message: Any) -> AsyncIterator[Any]: + def receive(self, message: Any) -> AsyncContentStream[Any]: """Receives a message and processes it. Args: @@ -598,7 +640,8 @@ async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> Un UnaryRequest[T]: A UnaryRequest object containing the received message. """ - message = await receive_unary_message(conn, t) + stream = conn.receive(t) + message = await stream.ensure_single() method = HTTPMethod.POST get_http_method = get_callable_attribute(conn, "get_http_method") @@ -626,8 +669,10 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S peer information, request headers, and HTTP method. """ + stream = conn.receive(t) + return StreamRequest( - messages=receive_stream_message(conn, t, conn.spec), + messages=stream, spec=conn.spec, peer=conn.peer, headers=conn.request_headers, @@ -635,41 +680,6 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S ) -async def receive_stream_message[T](conn: StreamingHandlerConn, t: type[T], spec: Spec) -> AsyncIterator[T]: - """Asynchronously receives and yields messages from a streaming connection. - - This function listens to a streaming connection and yields messages of the specified type. - - Args: - conn (StreamingHandlerConn): The streaming connection handler. - t (type[T]): The type of messages to receive. - spec (Spec): The specification for the request. - - Yields: - AsyncIterator[T]: An asynchronous iterator of messages of type T. 
- - """ - if spec.stream_type == StreamType.ServerStream: - count = 0 - async for message in conn.receive(t): - count += 1 - if count > 1: - raise ConnectError( - f"received extra input message for {conn.spec.procedure} method", - Code.UNIMPLEMENTED, - ) - yield message - - if count == 0: - raise ConnectError( - f"missing input message for {conn.spec.procedure} method", - Code.UNIMPLEMENTED, - ) - else: - async for message in conn.receive(t): - yield message - - async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryResponse[T]: """Receive a unary response from a streaming client connection. diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index 8dbf65e..c083723 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -26,6 +26,7 @@ from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name from connect.connect import ( Address, + AsyncContentStream, Peer, Spec, StreamingClientConn, @@ -687,7 +688,7 @@ async def receive_message(self, message: Any) -> AsyncIterator[Any]: obj = await self.unmarshaler.unmarshal(message) yield obj - def receive(self, message: Any) -> AsyncIterator[Any]: + def receive(self, message: Any) -> AsyncContentStream[Any]: """Receives a message, unmarshals it, and returns the resulting object. Args: @@ -697,7 +698,10 @@ def receive(self, message: Any) -> AsyncIterator[Any]: AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ - return self.receive_message(message) + return AsyncContentStream( + self.receive_message(message), + stream_type=self.spec.stream_type, + ) @property def request_headers(self) -> Headers: @@ -1534,19 +1538,44 @@ def peer(self) -> Peer: """ return self._peer - async def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message, unmarshals it, and returns the resulting object. 
+ async def receive_message(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously receives a message and yields unmarshaled objects. + + This method unmarshals the received message and yields each + unmarshaled object one by one as an asynchronous iterator. Args: - message (Any): The message to be unmarshaled. + message (Any): The message to unmarshal. Returns: - Any: The unmarshaled object. + AsyncIterator[Any]: An asynchronous iterator yielding unmarshaled objects. + + Yields: + Any: Each unmarshaled object from the message. """ async for obj, _ in self.unmarshaler.unmarshal(message): yield obj + def receive(self, message: Any) -> AsyncContentStream[Any]: + """Receives a message and returns an asynchronous content stream. + + This method processes the incoming message through the receive_message method + and wraps the result in an AsyncContentStream with the appropriate stream type. + + Args: + message (Any): The message to be processed. + + Returns: + AsyncContentStream[Any]: An asynchronous stream of content based on the + processed message, configured with the specification's stream type. + + """ + return AsyncContentStream( + iterable=self.receive_message(message), + stream_type=self.spec.stream_type, + ) + @property def request_headers(self) -> Headers: """Retrieve the headers from the request. 
diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index 6068834..3c05822 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -10,7 +10,7 @@ from connect.code import Code from connect.codec import Codec, CodecNameType from connect.compression import COMPRESSION_IDENTITY, Compression -from connect.connect import Address, Peer, Spec, StreamingHandlerConn, StreamType +from connect.connect import Address, AsyncContentStream, Peer, Spec, StreamingHandlerConn, StreamType from connect.envelope import EnvelopeReader, EnvelopeWriter from connect.error import ConnectError from connect.headers import Headers @@ -280,43 +280,7 @@ def spec(self) -> Spec: def peer(self) -> Peer: return self._peer - async def receive_unary(self, message: Any) -> AsyncIterator[Any]: - """Receive a single message in unary mode. - - Args: - message (Any): The message to be received and processed. - - Returns: - AsyncIterator[Any]: An async iterator yielding exactly one item. - - Raises: - ConnectError: If no message is received or multiple messages are received. - - """ - count = 0 - async for obj in self.unmarshaler.unmarshal(message): - count += 1 - if count > 1: - raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) - yield obj - - if count == 0: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) - - async def receive_streaming(self, message: Any) -> AsyncIterator[Any]: - """Receive messages in streaming mode. - - Args: - message (Any): The message to be received and processed. - - Returns: - AsyncIterator[Any]: An async iterator yielding all messages. - - """ - async for obj in self.unmarshaler.unmarshal(message): - yield obj - - def receive(self, message: Any) -> AsyncIterator[Any]: + def receive(self, message: Any) -> AsyncContentStream[Any]: """Receives a message and processes it. 
Args: @@ -327,12 +291,7 @@ def receive(self, message: Any) -> AsyncIterator[Any]: this will yield exactly one item. """ - # Different behavior based on streaming mode - if not self._is_streaming and self.spec.stream_type == StreamType.Unary: - return self.receive_unary(message) - else: - # Streaming mode - simply pass through all messages - return self.receive_streaming(message) + return AsyncContentStream(self.unmarshaler.unmarshal(message), self.spec.stream_type) @property def request_headers(self) -> Headers: From 7557c89ba29e537cfbc9bc7aa17da7f912c4565a Mon Sep 17 00:00:00 2001 From: tsubakiky Date: Thu, 1 May 2025 10:42:02 +0900 Subject: [PATCH 4/4] all: update doc --- src/connect/connect.py | 179 ++++++++++++---------- src/connect/envelope.py | 35 ++++- src/connect/handler.py | 19 +-- src/connect/protocol_connect.py | 40 +++-- src/connect/protocol_grpc.py | 264 +++++++++++++++++++++++++++++++- src/connect/response_trailer.py | 50 +++++- 6 files changed, 477 insertions(+), 110 deletions(-) diff --git a/src/connect/connect.py b/src/connect/connect.py index 4a4bfea..a5ede48 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping from enum import Enum from http import HTTPMethod -from typing import Any, Protocol, cast +from typing import Any, cast from pydantic import BaseModel @@ -319,16 +319,63 @@ async def aclose(self) -> None: class AsyncContentStream[T](AsyncIterable[T]): - """AsyncContentStream is a wrapper around AsyncIterable to provide a consistent interface.""" + """AsyncContentStream is a generic asynchronous stream wrapper for async iterables, providing validation and iteration utilities based on stream type. + + Type Parameters: + T: The type of elements yielded by the asynchronous iterable. + + iterable (AsyncIterable[T]): The asynchronous iterable to wrap. 
+ stream_type (StreamType): The type of stream (e.g., Unary, ServerStream) that determines validation behavior. + + Attributes: + _iterable (AsyncIterable[T]): The underlying asynchronous iterable. + stream_type (StreamType): The type of stream this instance represents. + + """ def __init__(self, iterable: AsyncIterable[T], stream_type: StreamType) -> None: + """Initialize a stream wrapper for an async iterable. + + This constructor stores the provided async iterable and its corresponding + stream type for later processing. + + Args: + iterable: An asynchronous iterable containing elements of type T. + stream_type: The type of stream this iterable represents. + + Returns: + None + + """ self._iterable = iterable self.stream_type = stream_type - def _needs_single_message_validation(self) -> bool: + def _needs_single_content_validation(self) -> bool: + """Determine if single message validation is required based on the stream type. + + Returns: + bool: True if the stream type is Unary or ServerStream, indicating that single message validation is needed; False otherwise. + + """ return self.stream_type == StreamType.Unary or self.stream_type == StreamType.ServerStream - async def _ensure_single(self, iterable: AsyncIterable[T]) -> AsyncIterator[T]: + async def _validate_single_content_stream(self, iterable: AsyncIterable[T]) -> AsyncIterator[T]: + """Validate that the given asynchronous iterable yields exactly one item. + + Iterates over the provided async iterable and ensures that it produces a single item. + If more than one item is yielded, raises a ConnectError indicating a protocol error. + If no items are yielded, also raises a ConnectError indicating a protocol error. + + Args: + iterable (AsyncIterable[T]): The asynchronous iterable to validate. + + Yields: + T: The single item from the iterable. + + Raises: + ConnectError: If the iterable yields zero or more than one item. 
+
+        """
         count = 0
         async for item in iterable:
             if count > 0:
@@ -341,17 +388,38 @@ async def _ensure_single(self, iterable: AsyncIterable[T]) -> AsyncIterator[T]:
             raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED)

     async def __aiter__(self) -> AsyncIterator[T]:
-        """Asynchronously iterate over the content stream."""
-        if self._needs_single_message_validation():
-            async for item in self._ensure_single(self._iterable):
+        """Asynchronously iterates over the underlying iterable.
+
+        If single message validation is required, wraps the iterable with a validation step.
+        Otherwise, yields items directly from the iterable.
+
+        Yields:
+            T: Items from the underlying asynchronous iterable.
+
+        """
+        if self._needs_single_content_validation():
+            async for item in self._validate_single_content_stream(self._iterable):
                 yield item
         else:
            async for item in self._iterable:
                 yield item

     async def ensure_single(self) -> T:
+        """Asynchronously ensures that exactly one message is present in the iterable.
+
+        Iterates over the provided asynchronous iterable and retrieves a single message.
+        Raises a ConnectError if no messages are found or if more than one message is
+        present, since either case indicates a protocol error.
+
+        Returns:
+            T: The single message retrieved from the iterable.
+
+        Raises:
+            ConnectError: If the iterable yields zero or more than one message.
+
+        """
         message = None
-        async for item in self._ensure_single(self._iterable):
+        async for item in self._validate_single_content_stream(self._iterable):
             message = item

         if message is None:
@@ -399,14 +467,16 @@ def peer(self) -> Peer:

     @abc.abstractmethod
     def receive(self, message: Any) -> AsyncContentStream[Any]:
-        """Receives a message and processes it.
+        """Receives a message and returns an asynchronous content stream.

         Args:
-            message (Any): The message to be received and processed.
+            message (Any): The message to be processed.
Returns: - AsyncIterator[Any]: An async iterator of processing results. - For unary operations, this will yield exactly one item. + AsyncContentStream[Any]: An asynchronous stream of content resulting from processing the message. + + Raises: + NotImplementedError: This method should be implemented by subclasses. """ raise NotImplementedError() @@ -477,15 +547,6 @@ async def send_error(self, error: ConnectError) -> None: """ raise NotImplementedError() - def is_unary(self) -> bool: - """Check if this connection is for a unary operation. - - Returns: - bool: True if this is a unary operation connection, False otherwise. - - """ - return self.spec.stream_type == StreamType.Unary - class UnaryClientConn: """Abstract base class for a streaming client connection.""" @@ -535,6 +596,11 @@ def on_request_send(self, fn: Callable[..., Any]) -> None: """Handle the request send event.""" raise NotImplementedError() + @abc.abstractmethod + async def aclose(self) -> None: + """Asynchronously close the connection.""" + raise NotImplementedError() + class StreamingClientConn: """Abstract base class for a streaming client connection.""" @@ -592,43 +658,6 @@ async def aclose(self) -> None: raise NotImplementedError() -class ReceiveConn(Protocol): - """A protocol that defines the methods required for receiving connections.""" - - @property - @abc.abstractmethod - def spec(self) -> Spec: - """Retrieve the specification for the current object. - - This method should be implemented by subclasses to return an instance - of the `Spec` class that defines the specification for the object. - - Raises: - NotImplementedError: If the method is not implemented by a subclass. - - Returns: - Spec: The specification for the current object. - - """ - raise NotImplementedError() - - @abc.abstractmethod - def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and processes it. - - Args: - message (Any): The message to be received and processed. 
- - Returns: - AsyncIterator[Any]: An async iterator of processing results. - - Raises: - NotImplementedError: This method should be implemented by subclasses. - - """ - raise NotImplementedError() - - async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> UnaryRequest[T]: """Receives a unary request from the given connection and returns a UnaryRequest object. @@ -770,36 +799,22 @@ async def recieve_stream_response[T]( return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) -async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: - """Receive a unary message from the given connection. +async def receive_unary_message[T](conn: UnaryClientConn, t: type[T]) -> T: + """Asynchronously receives a single unary message from the given connection. Args: - conn (ReceiveConn): The connection object to receive the message from. - t (type[T]): The type of the message to be received. + conn (UnaryClientConn): The unary client connection to receive the message from. + t (type[T]): The expected type of the message to be received. Returns: T: The received message of type T. Raises: - ConnectError: If no message is received or multiple messages are received. - - """ - first = None - count = 0 + Exception: If receiving the message fails or more than one message is received. - async for message in conn.receive(t): - count += 1 - if count > 1: - raise ConnectError( - f"received extra input message for {conn.spec.procedure} method", - Code.UNIMPLEMENTED, - ) - first = message + Note: + This function ensures that exactly one message is received from the connection. 
- if first is None: - raise ConnectError( - f"missing input message for {conn.spec.procedure} method", - Code.UNIMPLEMENTED, - ) - - return first + """ + single_message = await _receive_exactly_one(conn.receive(t), conn.aclose) + return single_message diff --git a/src/connect/envelope.py b/src/connect/envelope.py index 0a3b1ba..9d635b7 100644 --- a/src/connect/envelope.py +++ b/src/connect/envelope.py @@ -126,6 +126,27 @@ def is_set(self, flag: EnvelopeFlags) -> bool: class EnvelopeWriter: + """EnvelopeWriter is responsible for marshaling messages, optionally compressing them, and writing them into envelopes for transmission. + + Attributes: + codec (Codec | None): The codec used for encoding and decoding messages. + send_max_bytes (int): The maximum number of bytes allowed per message. + compression (Compression | None): The compression method to use, or None for no compression. + + Methods: + __init__(codec, compression, compress_min_bytes, send_max_bytes): + Initializes the EnvelopeWriter with the specified codec, compression, and size constraints. + + async _marshal(messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + Asynchronously marshals and optionally compresses messages from an async iterable, yielding encoded envelope bytes. + Raises ConnectError if marshaling fails or message size exceeds the allowed limit. + + write_envelope(data: bytes, flags: EnvelopeFlags) -> Envelope: + Writes an envelope, optionally compressing its data if conditions are met, and updates envelope flags accordingly. + Raises ConnectError if the (compressed) message size exceeds the allowed maximum. + + """ + codec: Codec | None compress_min_bytes: int send_max_bytes: int @@ -174,16 +195,22 @@ async def _marshal(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: yield env.encode() def write_envelope(self, data: bytes, flags: EnvelopeFlags) -> Envelope: - """Write an envelope, optionally compressing its data if certain conditions are met. 
+ """Write an envelope containing the provided data, applying compression if required. Args: - env (Envelope): The envelope to be written. + data (bytes): The message payload to be written into the envelope. + flags (EnvelopeFlags): Flags indicating envelope properties, such as compression. Returns: - Envelope: The envelope with possibly compressed data and updated flags. + Envelope: An envelope object containing the (optionally compressed) data and updated flags. Raises: - ConnectError: If the size of the envelope data exceeds the maximum allowed size. + ConnectError: If the (compressed or uncompressed) data size exceeds the configured send_max_bytes limit. + + Notes: + - Compression is applied only if the flags do not already indicate compression, + compression is enabled, and the data size exceeds the minimum threshold. + - The flags are updated to include the compressed flag if compression is performed. """ if flags in EnvelopeFlags.compressed or self.compression is None or len(data) < self.compress_min_bytes: diff --git a/src/connect/handler.py b/src/connect/handler.py index 6a3c221..61efdb2 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -297,13 +297,10 @@ async def handle(self, request: Request) -> Response: async def _handle( self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter ) -> None: - # Check the stream type of the handler - if getattr(self, "stream_type", StreamType.Unary) != StreamType.Unary: - self._is_stream_handler = True - await self.stream_handle(request, response_headers, response_trailers, writer) - else: - self._is_stream_handler = False + if getattr(self, "stream_type", StreamType.Unary) == StreamType.Unary: await self.unary_handle(request, response_headers, response_trailers, writer) + else: + await self.stream_handle(request, response_headers, response_trailers, writer) def is_stream(self) -> bool: """Determine if this handler is a stream handler. 
@@ -346,7 +343,6 @@ async def stream_handle( ConnectError: If an internal error occurs during the handling of the stream. """ - self._is_stream_handler = True conn = await self.protocol_handler.conn(request, response_headers, response_trailers, writer, is_streaming=True) if conn is None: return @@ -359,6 +355,7 @@ async def stream_handle( await self.implementation(conn, timeout_ms) else: await self.implementation(conn, None) + except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) await conn.send_error(error) @@ -382,7 +379,6 @@ async def unary_handle( None """ - self._is_stream_handler = False conn = await self.protocol_handler.conn(request, response_headers, response_trailers, writer) if conn is None: return @@ -395,6 +391,7 @@ async def unary_handle( await self.implementation(conn, timeout_ms) else: await self.implementation(conn, None) + except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) @@ -513,7 +510,7 @@ def __init__( procedure: str, stream: StreamFunc[T_Request, T_Response], input: type[T_Request], - output: type[T_Response], # noqa: ARG002 + output: type[T_Response], options: ConnectOptions | None = None, ) -> None: """Initialize a new handler instance. @@ -596,7 +593,7 @@ def __init__( procedure: str, stream: StreamFunc[T_Request, T_Response], input: type[T_Request], - output: type[T_Response], # noqa: ARG002 + output: type[T_Response], options: ConnectOptions | None = None, ) -> None: """Initialize a new instance of the handler. @@ -696,7 +693,7 @@ def __init__( procedure: str, stream: StreamFunc[T_Request, T_Response], input: type[T_Request], - output: type[T_Response], # noqa: ARG002 + output: type[T_Response], options: ConnectOptions | None = None, ) -> None: """Initialize a bidirectional streaming handler. 
diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index c083723..7de43b5 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1538,7 +1538,7 @@ def peer(self) -> Peer: """ return self._peer - async def receive_message(self, message: Any) -> AsyncIterator[Any]: + async def _receive_message(self, message: Any) -> AsyncIterator[Any]: """Asynchronously receives a message and yields unmarshaled objects. This method unmarshals the received message and yields each @@ -1572,7 +1572,7 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: """ return AsyncContentStream( - iterable=self.receive_message(message), + iterable=self._receive_message(message), stream_type=self.spec.stream_type, ) @@ -2084,22 +2084,30 @@ def peer(self) -> Peer: """ return self._peer - def receive(self, message: Any) -> AsyncIterator[Any]: - """Asynchronously receives a message, unmarshals it, and returns the resulting object. + async def _receive_message(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously receives and unmarshals a message, yielding the resulting object. Args: message (Any): The message to be unmarshaled. - Returns: - AsyncIterator[Any]: An async iterator yielding the unmarshaled object. + Yields: + Any: The unmarshaled object. """ + obj = await self.unmarshaler.unmarshal(message) + yield obj - async def _receive() -> AsyncIterator[Any]: - obj = await self.unmarshaler.unmarshal(message) - yield obj + def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message and returns an asynchronous iterator over the processed message. - return _receive() + Args: + message (Any): The message to be received and processed. + + Returns: + AsyncIterator[Any]: An asynchronous iterator yielding processed message(s). 
+
+        """
+        return self._receive_message(message)

     @property
     def request_headers(self) -> Headers:
@@ -2316,6 +2324,18 @@ def event_hooks(self, event_hooks: dict[str, list[EventHook]]) -> None:
             "response": list(event_hooks.get("response", [])),
         }

+    async def aclose(self) -> None:
+        """Asynchronously closes the connection or releases any resources held by the object.
+
+        This method should be called when the object is no longer needed to ensure proper cleanup.
+        Currently, this implementation does not perform any actions, but it can be overridden in subclasses.
+
+        Returns:
+            None
+
+        """
+        return
+

 def connect_validate_unary_response_content_type(
     request_codec_name: str,
diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py
index 3c05822..ef106bd 100644
--- a/src/connect/protocol_grpc.py
+++ b/src/connect/protocol_grpc.py
@@ -1,3 +1,5 @@
+"""Provides classes and functions for handling gRPC protocol."""
+
 import base64
 import re
 import urllib.parse
@@ -58,12 +60,38 @@


 class ProtocolGPRC(Protocol):
-    web: bool
+    """ProtocolGPRC is a protocol implementation for handling gRPC and gRPC-Web requests.
+
+    Attributes:
+        web (bool): Indicates whether to use gRPC-Web (True) or standard gRPC (False).
+
+    """

     def __init__(self, web: bool) -> None:
+        """Initialize the instance.
+
+        Args:
+            web (bool): Indicates whether the instance is for web usage.
+
+        """
         self.web = web

     def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler:
+        """Create and return a GRPCHandler instance configured with appropriate content types based on the provided parameters.
+
+        Args:
+            params (ProtocolHandlerParams): The parameters containing codec information and other handler configuration.
+
+        Returns:
+            ProtocolHandler: An instance of GRPCHandler initialized with the correct content types for gRPC or gRPC-Web.
+
+        Behavior:
+        - Determines the default and prefix content types based on whether gRPC-Web is enabled.
+ - Constructs a list of supported content types from the available codecs. + - Adds the bare content type if the PROTO codec is present. + - Returns a GRPCHandler with the computed content types. + + """ bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX if self.web: bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX @@ -91,23 +119,71 @@ def client(self, params: ProtocolClientParams) -> ProtocolClient: class GRPCHandler(ProtocolHandler): + """GRPCHandler is a protocol handler for gRPC and gRPC-Web requests. + + This class implements the ProtocolHandler interface to handle gRPC protocol requests, + including negotiation of compression, codec selection, and connection management for + both standard gRPC and gRPC-Web. It supports content type negotiation, payload handling, + and manages the lifecycle of a gRPC connection, including streaming and non-streaming + requests. + + Attributes: + params (ProtocolHandlerParams): Configuration parameters for the handler, including codecs and compressions. + web (bool): Indicates if the handler is for gRPC-Web. + accept (list[str]): List of accepted content types. + + """ + params: ProtocolHandlerParams web: bool accept: list[str] def __init__(self, params: ProtocolHandlerParams, web: bool, accept: list[str]) -> None: + """Initialize the ProtocolHandler with the given parameters. + + Args: + params (ProtocolHandlerParams): The parameters required for the protocol handler. + web (bool): Indicates whether the handler is for web usage. + accept (list[str]): A list of accepted content types. + + Returns: + None + + """ self.params = params self.web = web self.accept = accept @property def methods(self) -> list[HTTPMethod]: + """Returns a list of allowed HTTP methods for gRPC protocol. + + Returns: + list[HTTPMethod]: A list containing the HTTP methods permitted for gRPC communication. 
+ + """ return GRPC_ALLOWED_METHODS def content_types(self) -> list[str]: + """Return a list of accepted content types. + + Returns: + list[str]: A list of MIME types that are accepted. + + """ return self.accept def can_handle_payload(self, _: Request, content_type: str) -> bool: + """Determine if the given content type is supported by this handler. + + Args: + _ (Request): The request object (unused). + content_type (str): The MIME type of the payload to check. + + Returns: + bool: True if the content type is accepted, False otherwise. + + """ return content_type in self.accept async def conn( @@ -153,7 +229,6 @@ async def conn( query=request.query_params, ) - # Create a single unified handler class with streaming flag conn = GRPCHandlerConn( writer=writer, spec=self.params.spec, @@ -185,6 +260,22 @@ async def conn( class GRPCMarshaler(EnvelopeWriter): + """GRPCMarshaler is responsible for marshaling messages into the gRPC wire format. + + Args: + web (bool): Indicates whether to use the gRPC-web protocol. + codec (Codec | None): The codec used for encoding/decoding messages. + compression (Compression | None): The compression algorithm to use, if any. + compress_min_bytes (int): Minimum message size in bytes before compression is applied. + send_max_bytes (int): Maximum allowed size of a message to send. + + Methods: + marshal(messages: AsyncIterable[bytes]) -> AsyncIterator[bytes]: + Asynchronously marshals a stream of message bytes into the gRPC wire format. + Yields marshaled message bytes ready for transmission. + + """ + web: bool def __init__( @@ -195,15 +286,52 @@ def __init__( compress_min_bytes: int, send_max_bytes: int, ) -> None: + """Initialize the protocol with the specified configuration. + + Args: + web (bool): Indicates whether the protocol is used in a web context. + codec (Codec | None): The codec to use for encoding/decoding messages, or None for default. 
+ compression (Compression | None): The compression algorithm to use, or None for no compression. + compress_min_bytes (int): The minimum number of bytes before compression is applied. + send_max_bytes (int): The maximum number of bytes allowed to send in a single message. + + Returns: + None + + """ super().__init__(codec, compression, compress_min_bytes, send_max_bytes) self.web = web async def marshal(self, messages: AsyncIterable[bytes]) -> AsyncIterator[bytes]: + """Asynchronously marshals a stream of byte messages. + + Args: + messages (AsyncIterable[bytes]): An asynchronous iterable of byte messages to be marshaled. + + Yields: + AsyncIterator[bytes]: An asynchronous iterator yielding marshaled byte messages. + + """ async for message in self._marshal(messages): yield message class GRPCUnmarshaler(EnvelopeReader): + """GRPCUnmarshaler is a specialized EnvelopeReader for handling gRPC message unmarshaling. + + Args: + codec (Codec | None): The codec used for decoding messages. + read_max_bytes (int): The maximum number of bytes to read from the stream. + stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read messages from. + compression (Compression | None, optional): Compression algorithm to use for decompressing messages. + + Methods: + async unmarshal(message: Any) -> AsyncIterator[Any]: + Asynchronously unmarshals the given message, yielding each decoded object. + Iterates over the results of the internal _unmarshal method, yielding only the object part of each tuple. + + """ + def __init__( self, codec: Codec | None, @@ -211,14 +339,47 @@ def __init__( stream: AsyncIterable[bytes] | None = None, compression: Compression | None = None, ) -> None: + """Initialize the protocol gRPC handler. + + Args: + codec (Codec | None): The codec to use for encoding/decoding messages. Can be None. + read_max_bytes (int): The maximum number of bytes to read from the stream. 
+ stream (AsyncIterable[bytes] | None, optional): An asynchronous iterable stream of bytes. Defaults to None. + compression (Compression | None, optional): The compression method to use. Defaults to None. + + """ super().__init__(codec, read_max_bytes, stream, compression) async def unmarshal(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously unmarshals a given message and yields each resulting object. + + Args: + message (Any): The message to be unmarshaled. + + Yields: + Any: Each object obtained from unmarshaling the message. + + """ async for obj, _ in self._unmarshal(message): yield obj class GRPCHandlerConn(StreamingHandlerConn): + """GRPCHandlerConn is a handler class for managing gRPC protocol connections within a streaming server context. + + This class encapsulates the logic for handling gRPC requests and responses, including marshaling and unmarshaling messages, + managing request and response headers/trailers, handling timeouts, and enforcing protocol-specific constraints for unary and streaming operations. + + Attributes: + _spec (Spec): The specification object describing the protocol or service. + _peer (Peer): The peer information for the current connection. + _request_headers (Headers): The headers received with the request. + _response_headers (Headers): The headers to include in the response. + _response_trailers (Headers): The trailers to include in the response. + _is_streaming (bool): Indicates if the connection is streaming. + + """ + _spec: Spec _peer: Peer writer: ServerResponseWriter @@ -241,6 +402,20 @@ def __init__( response_trailers: Headers | None = None, is_streaming: bool = False, ) -> None: + """Initialize a new instance of the class. + + Args: + writer (ServerResponseWriter): The writer used to send responses to the client. + spec (Spec): The specification object describing the protocol or service. + peer (Peer): The peer information for the current connection. 
+ marshaler (GRPCMarshaler): The marshaler used to serialize response messages. + unmarshaler (GRPCUnmarshaler): The unmarshaler used to deserialize request messages. + request_headers (Headers): The headers received with the request. + response_headers (Headers): The headers to include in the response. + response_trailers (Headers | None, optional): The trailers to include in the response. Defaults to None. + is_streaming (bool, optional): Indicates if the connection is streaming. Defaults to False. + + """ self.writer = writer self._spec = spec self._peer = peer @@ -252,6 +427,19 @@ def __init__( self._is_streaming = is_streaming def parse_timeout(self) -> float | None: + """Parse the gRPC timeout value from the request headers and returns it as seconds. + + Returns: + float | None: The timeout value in seconds if present and valid, otherwise None. + + Raises: + ConnectError: If the timeout value is present but invalid or too long. + + Notes: + - The timeout is extracted from the gRPC header and must match the expected format. + - If the timeout unit is hours and exceeds the maximum allowed, None is returned. + + """ timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) if not timeout: return None @@ -274,10 +462,22 @@ def parse_timeout(self) -> float | None: @property def spec(self) -> Spec: + """Returns the specification object associated with this instance. + + Returns: + Spec: The specification object. + + """ return self._spec @property def peer(self) -> Peer: + """Returns the associated Peer object. + + Returns: + Peer: The peer instance associated with this object. + + """ return self._peer def receive(self, message: Any) -> AsyncContentStream[Any]: @@ -295,6 +495,12 @@ def receive(self, message: Any) -> AsyncContentStream[Any]: @property def request_headers(self) -> Headers: + """Returns the headers associated with the current request. + + Returns: + Headers: The headers of the request. 
+ + """ return self._request_headers async def send(self, messages: AsyncIterable[Any]) -> None: @@ -333,10 +539,25 @@ async def send(self, messages: AsyncIterable[Any]) -> None: @property def response_headers(self) -> Headers: + """Returns the response headers associated with the current request. + + Returns: + Headers: The headers returned in the response. + + """ return self._response_headers @property def response_trailers(self) -> Headers: + """Returns the response trailers as headers. + + Response trailers are additional metadata sent by the server after the response body, + typically used in gRPC and HTTP/2 protocols. + + Returns: + Headers: The response trailers associated with the current response. + + """ return self._response_trailers async def marshal_with_error_handling(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: @@ -359,6 +580,18 @@ async def marshal_with_error_handling(self, messages: AsyncIterable[Any]) -> Asy grpc_error_to_trailer(self.response_trailers, error) async def send_error(self, error: ConnectError) -> None: + """Send an error response over gRPC by converting the provided ConnectError into gRPC trailers. + + Args: + error (ConnectError): The error to be sent as a gRPC trailer. + + Returns: + None + + This method updates the response trailers with the error information and writes a streaming response + with the appropriate headers and trailers to the client. + + """ grpc_error_to_trailer(self.response_trailers, error) await self.writer.write( @@ -369,6 +602,18 @@ async def send_error(self, error: ConnectError) -> None: def grpc_codec_from_content_type(web: bool, content_type: str) -> str: + """Determine the gRPC codec name from the given content type string. + + Args: + web (bool): Indicates whether the request is a gRPC-web request. + content_type (str): The content type string to parse. + + Returns: + str: The codec name extracted from the content type. 
If the content type matches the default gRPC or gRPC-web content type, + returns the default codec name. Otherwise, extracts and returns the codec name from the content type prefix, or returns + the original content type if no known prefix is found. + + """ if (not web and content_type == GRPC_CONTENT_TYPE_DEFAULT) or ( web and content_type == GRPC_WEB_CONTENT_TYPE_DEFAULT ): @@ -383,6 +628,21 @@ def grpc_codec_from_content_type(web: bool, content_type: str) -> str: def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: + """Convert a ConnectError to gRPC trailer headers. + + Args: + trailer (Headers): The trailer headers dictionary to update with gRPC error information. + error (ConnectError | None): The error to convert. If None, indicates success. + + Side Effects: + Modifies the `trailer` dictionary in-place to include gRPC status, message, and optional details. + + Notes: + - If `error` is None, sets the gRPC status header to "0" (OK). + - If `ConnectError.wire_error` is False, updates the trailer with error metadata excluding protocol headers. + - Serializes error details using protobuf if present, encoding them in base64 for the trailer. + + """ if error is None: trailer[GRPC_HEADER_STATUS] = "0" return diff --git a/src/connect/response_trailer.py b/src/connect/response_trailer.py index 66787a1..a566ef7 100644 --- a/src/connect/response_trailer.py +++ b/src/connect/response_trailer.py @@ -1,4 +1,4 @@ -from __future__ import annotations +"""Streaming HTTP response with support for trailers.""" import typing from functools import partial @@ -17,6 +17,21 @@ class StreamingResponseWithTrailers(Response): + """A streaming HTTP response class that supports HTTP trailers. + + This class extends the standard response to allow sending HTTP trailers + at the end of a streamed response body, if supported by the ASGI server. + + Attributes: + body_iterator (AsyncContentStream): An asynchronous iterator over the response body content. 
+ status_code (int): HTTP status code for the response. + media_type (str | None): The media type of the response. + background (BackgroundTask | None): Optional background task to run after response is sent. + headers (Mapping[str, str]): HTTP headers for the response. + _trailers (Mapping[str, str] | None): HTTP trailers to send after the response body. + + """ + body_iterator: AsyncContentStream def __init__( @@ -29,6 +44,21 @@ def __init__( media_type: str | None = None, background: BackgroundTask | None = None, ) -> None: + """Initialize a response object with optional HTTP trailers. + + Args: + content (ContentStream): The response body content, which can be an async iterable or a regular iterable. + status_code (int, optional): HTTP status code for the response. Defaults to 200. + headers (typing.Mapping[str, str] | None, optional): HTTP headers to include in the response. Defaults to None. + trailers (typing.Mapping[str, str] | None, optional): HTTP trailers to include in the response. Defaults to None. + media_type (str | None, optional): The media type of the response. If None, uses the default media type. Defaults to None. + background (BackgroundTask | None, optional): A background task to run after the response is sent. Defaults to None. + + Notes: + - If `content` is not an async iterable, it will be wrapped to run in a thread pool. + - If trailers are provided, their names will be added to the "Trailer" header. + + """ if isinstance(content, typing.AsyncIterable): self.body_iterator = content else: @@ -69,6 +99,24 @@ async def _stream_response(self, send: Send, trailers_supported: bool) -> None: }) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Handle the ASGI call interface for streaming HTTP responses with optional support for HTTP trailers. + + This method determines the ASGI spec version and whether HTTP response trailers are supported. 
+ For ASGI spec version >= 2.4, it streams the response and handles client disconnects. + For earlier versions, it concurrently streams the response and listens for client disconnects, + cancelling the response stream if a disconnect is detected. + + After sending the response, if a background task is provided, it is awaited. + + Args: + scope (Scope): The ASGI connection scope. + receive (Receive): Awaitable callable to receive ASGI messages. + send (Send): Awaitable callable to send ASGI messages. + + Raises: + ClientDisconnect: If the client disconnects during response streaming. + + """ spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) trailers_supported = "http.response.trailers" in scope.get("extensions", {})