diff --git a/conformance/run-testcase.txt b/conformance/run-testcase.txt index b42c6a1..26b35a5 100644 --- a/conformance/run-testcase.txt +++ b/conformance/run-testcase.txt @@ -1 +1 @@ -gRPC Unexpected Requests/HTTPVersion:2/TLS:false/unary/no-request +gRPC Unexpected Requests/HTTPVersion:2/TLS:true/unary/multiple-requests diff --git a/src/connect/connect.py b/src/connect/connect.py index 998f8f3..a5ede48 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping from enum import Enum from http import HTTPMethod -from typing import Any, Protocol, cast +from typing import Any, cast from pydantic import BaseModel @@ -318,123 +318,114 @@ async def aclose(self) -> None: await aclose() -class UnaryHandlerConn(abc.ABC): - """Abstract base class for a streaming handler connection. - - This class defines the interface for handling streaming connections, including - methods for specifying the connection, handling peer communication, receiving - and sending messages, and managing request and response headers and trailers. - - """ - - @abc.abstractmethod - def parse_timeout(self) -> float | None: - """Parse the timeout value.""" - raise NotImplementedError() - - @property - @abc.abstractmethod - def spec(self) -> Spec: - """Return the specification details. +class AsyncContentStream[T](AsyncIterable[T]): + """AsyncContentStream is a generic asynchronous stream wrapper for async iterables, providing validation and iteration utilities based on stream type. - Returns: - Spec: The specification details. + Type Parameters: + T: The type of elements yielded by the asynchronous iterable. - """ - raise NotImplementedError() + iterable (AsyncIterable[T]): The asynchronous iterable to wrap. + stream_type (StreamType): The type of stream (e.g., Unary, ServerStream) that determines validation behavior. 
- @property - @abc.abstractmethod - def peer(self) -> Peer: - """Establish a connection to a peer in the network. + Attributes: + _iterable (AsyncIterable[T]): The underlying asynchronous iterable. + stream_type (StreamType): The type of stream this instance represents. - Returns: - Any: The result of the connection attempt. The exact type and structure - of the return value will depend on the implementation details. + """ - """ - raise NotImplementedError() + def __init__(self, iterable: AsyncIterable[T], stream_type: StreamType) -> None: + """Initialize a stream wrapper for an async iterable. - @abc.abstractmethod - async def receive(self, message: Any) -> Any: - """Receives a message and processes it. + This constructor stores the provided async iterable and its corresponding + stream type for later processing. Args: - message (Any): The message to be received and processed. + iterable: An asynchronous iterable containing elements of type T. + stream_type: The type of stream this iterable represents. Returns: - Any: The result of processing the message. + None """ - raise NotImplementedError() + self._iterable = iterable + self.stream_type = stream_type - @property - @abc.abstractmethod - def request_headers(self) -> Headers: - """Generate and return the request headers. + def _needs_single_content_validation(self) -> bool: + """Determine if single message validation is required based on the stream type. Returns: - Any: The request headers. + bool: True if the stream type is Unary or ServerStream, indicating that single message validation is needed; False otherwise. """ - raise NotImplementedError() + return self.stream_type == StreamType.Unary or self.stream_type == StreamType.ServerStream - @abc.abstractmethod - async def send(self, message: Any) -> None: - """Send a message. + async def _validate_single_content_stream(self, iterable: AsyncIterable[T]) -> AsyncIterator[T]: + """Validate that the given asynchronous iterable yields exactly one item. 
- This method should be implemented by subclasses to define how the message - should be sent. + Iterates over the provided async iterable and ensures that it produces a single item. + If more than one item is yielded, raises a ConnectError indicating a protocol error. + If no items are yielded, also raises a ConnectError indicating a protocol error. Args: - message (Any): The message to be sent. + iterable (AsyncIterable[T]): The asynchronous iterable to validate. + + Yields: + T: The single item from the iterable. Raises: - NotImplementedError: If the method is not implemented by a subclass. + ConnectError: If the iterable yields zero or more than one item. """ - raise NotImplementedError() - - @property - @abc.abstractmethod - def response_headers(self) -> Headers: - """Retrieve the response headers. + count = 0 + async for item in iterable: + if count > 0: + raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) - Returns: - Any: The response headers. + yield item + count += 1 - """ - raise NotImplementedError() + if count == 0: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) - @property - @abc.abstractmethod - def response_trailers(self) -> Headers: - """Handle response trailers. + async def __aiter__(self) -> AsyncIterator[T]: + """Asynchronously iterates over the underlying iterable. - This method is intended to be overridden in subclasses to provide - specific functionality for processing response trailers. + If single message validation is required, wraps the iterable with a validation step. + Otherwise, yields items directly from the iterable. - Returns: - Any: The return type is not specified as this is a placeholder method. + Yields: + T: Items from the underlying asynchronous iterable. 
""" - raise NotImplementedError() + if self._needs_single_content_validation(): + async for item in self._validate_single_content_stream(self._iterable): + yield item + else: + async for item in self._iterable: + yield item - @abc.abstractmethod - async def send_error(self, error: ConnectError) -> None: - """Send an error message. + async def ensure_single(self) -> T: + """Asynchronously ensures that exactly one message is present in the iterable. - This method should be implemented to handle the sending of error messages - in a specific manner defined by the subclass. + Iterates over the provided asynchronous iterable and retrieves a single message. + Raises a ConnectError if no messages are found. If multiple messages are present, + only the last one will be returned (potential protocol error). - Args: - error (ConnectError): The error to be sent. + Returns: + T: The single message retrieved from the iterable. Raises: - NotImplementedError: If the method is not implemented by the subclass. + ConnectError: If no messages are found in the iterable. """ - raise NotImplementedError() + message = None + async for item in self._validate_single_content_stream(self._iterable): + message = item + + if message is None: + raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) + + return message class StreamingHandlerConn(abc.ABC): @@ -475,14 +466,17 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and processes it. + def receive(self, message: Any) -> AsyncContentStream[Any]: + """Receives a message and returns an asynchronous content stream. Args: - message (Any): The message to be received and processed. + message (Any): The message to be processed. Returns: - Any: The result of processing the message. + AsyncContentStream[Any]: An asynchronous stream of content resulting from processing the message. 
+ + Raises: + NotImplementedError: This method should be implemented by subclasses. """ raise NotImplementedError() @@ -493,7 +487,7 @@ def request_headers(self) -> Headers: """Generate and return the request headers. Returns: - Any: The request headers. + Headers: The request headers. """ raise NotImplementedError() @@ -504,6 +498,7 @@ async def send(self, messages: AsyncIterable[Any]) -> None: Args: messages (AsyncIterable[Any]): The messages to be sent. + For unary operations, this should be an iterable with a single item. Raises: NotImplementedError: This method should be implemented by subclasses. @@ -517,7 +512,7 @@ def response_headers(self) -> Headers: """Retrieve the response headers. Returns: - Any: The response headers. + Headers: The response headers. """ raise NotImplementedError() @@ -531,7 +526,7 @@ def response_trailers(self) -> Headers: specific functionality for processing response trailers. Returns: - Any: The return type is not specified as this is a placeholder method. + Headers: The response trailers. 
""" raise NotImplementedError() @@ -569,7 +564,7 @@ def peer(self) -> Peer: raise NotImplementedError() @abc.abstractmethod - async def receive(self, message: Any) -> Any: + def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and processes it.""" raise NotImplementedError() @@ -601,6 +596,11 @@ def on_request_send(self, fn: Callable[..., Any]) -> None: """Handle the request send event.""" raise NotImplementedError() + @abc.abstractmethod + async def aclose(self) -> None: + """Asynchronously close the connection.""" + raise NotImplementedError() + class StreamingClientConn: """Abstract base class for a streaming client connection.""" @@ -658,55 +658,19 @@ async def aclose(self) -> None: raise NotImplementedError() -class ReceiveConn(Protocol): - """A protocol that defines the methods required for receiving connections.""" - - @property - @abc.abstractmethod - def spec(self) -> Spec: - """Retrieve the specification for the current object. - - This method should be implemented by subclasses to return an instance - of the `Spec` class that defines the specification for the object. - - Raises: - NotImplementedError: If the method is not implemented by a subclass. - - Returns: - Spec: The specification for the current object. - - """ - raise NotImplementedError() - - @abc.abstractmethod - async def receive(self, message: Any) -> Any: - """Receives a message and processes it. - - Args: - message (Any): The message to be received and processed. - - Returns: - Any: The result of processing the message. - - Raises: - NotImplementedError: This method should be implemented by subclasses. - - """ - raise NotImplementedError() - - -async def receive_unary_request[T](conn: UnaryHandlerConn, t: type[T]) -> UnaryRequest[T]: +async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> UnaryRequest[T]: """Receives a unary request from the given connection and returns a UnaryRequest object. 
Args: - conn (UnaryHandlerConn): The connection from which to receive the unary request. + conn (StreamingHandlerConn): The connection from which to receive the unary request. t (type[T]): The type of the message to be received. Returns: UnaryRequest[T]: A UnaryRequest object containing the received message. """ - message = await receive_unary_message(conn, t) + stream = conn.receive(t) + message = await stream.ensure_single() method = HTTPMethod.POST get_http_method = get_callable_attribute(conn, "get_http_method") @@ -734,8 +698,10 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S peer information, request headers, and HTTP method. """ + stream = conn.receive(t) + return StreamRequest( - messages=receive_stream_message(conn, t, conn.spec), + messages=stream, spec=conn.spec, peer=conn.peer, headers=conn.request_headers, @@ -743,41 +709,6 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S ) -async def receive_stream_message[T](conn: StreamingHandlerConn, t: type[T], spec: Spec) -> AsyncIterator[T]: - """Asynchronously receives and yields messages from a streaming connection. - - This function listens to a streaming connection and yields messages of the specified type. - - Args: - conn (StreamingHandlerConn): The streaming connection handler. - t (type[T]): The type of messages to receive. - spec (Spec): The specification for the request. - - Yields: - AsyncIterator[T]: An asynchronous iterator of messages of type T. 
- - """ - if spec.stream_type == StreamType.ServerStream: - count = 0 - async for message in conn.receive(t): - count += 1 - if count > 1: - raise ConnectError( - f"received extra input message for {conn.spec.procedure} method", - Code.UNIMPLEMENTED, - ) - yield message - - if count == 0: - raise ConnectError( - f"missing input message for {conn.spec.procedure} method", - Code.UNIMPLEMENTED, - ) - else: - async for message in conn.receive(t): - yield message - - async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryResponse[T]: """Receive a unary response from a streaming client connection. @@ -868,16 +799,22 @@ async def recieve_stream_response[T]( return StreamResponse(receive_stream, conn.response_headers, conn.response_trailers) -async def receive_unary_message[T](conn: ReceiveConn, t: type[T]) -> T: - """Receive a unary message from the given connection. +async def receive_unary_message[T](conn: UnaryClientConn, t: type[T]) -> T: + """Asynchronously receives a single unary message from the given connection. Args: - conn (ReceiveConn): The connection object to receive the message from. - t (type[T]): The type of the message to be received. + conn (UnaryClientConn): The unary client connection to receive the message from. + t (type[T]): The expected type of the message to be received. Returns: T: The received message of type T. + Raises: + Exception: If receiving the message fails or more than one message is received. + + Note: + This function ensures that exactly one message is received from the connection. 
+ """ - message = await conn.receive(t) - return message + single_message = await _receive_exactly_one(conn.receive(t), conn.aclose) + return single_message diff --git a/src/connect/envelope.py b/src/connect/envelope.py index 0a3b1ba..9d635b7 100644 --- a/src/connect/envelope.py +++ b/src/connect/envelope.py @@ -126,6 +126,27 @@ def is_set(self, flag: EnvelopeFlags) -> bool: class EnvelopeWriter: + """EnvelopeWriter is responsible for marshaling messages, optionally compressing them, and writing them into envelopes for transmission. + + Attributes: + codec (Codec | None): The codec used for encoding and decoding messages. + send_max_bytes (int): The maximum number of bytes allowed per message. + compression (Compression | None): The compression method to use, or None for no compression. + + Methods: + __init__(codec, compression, compress_min_bytes, send_max_bytes): + Initializes the EnvelopeWriter with the specified codec, compression, and size constraints. + + async _marshal(messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + Asynchronously marshals and optionally compresses messages from an async iterable, yielding encoded envelope bytes. + Raises ConnectError if marshaling fails or message size exceeds the allowed limit. + + write_envelope(data: bytes, flags: EnvelopeFlags) -> Envelope: + Writes an envelope, optionally compressing its data if conditions are met, and updates envelope flags accordingly. + Raises ConnectError if the (compressed) message size exceeds the allowed maximum. + + """ + codec: Codec | None compress_min_bytes: int send_max_bytes: int @@ -174,16 +195,22 @@ async def _marshal(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: yield env.encode() def write_envelope(self, data: bytes, flags: EnvelopeFlags) -> Envelope: - """Write an envelope, optionally compressing its data if certain conditions are met. + """Write an envelope containing the provided data, applying compression if required. 
Args: - env (Envelope): The envelope to be written. + data (bytes): The message payload to be written into the envelope. + flags (EnvelopeFlags): Flags indicating envelope properties, such as compression. Returns: - Envelope: The envelope with possibly compressed data and updated flags. + Envelope: An envelope object containing the (optionally compressed) data and updated flags. Raises: - ConnectError: If the size of the envelope data exceeds the maximum allowed size. + ConnectError: If the (compressed or uncompressed) data size exceeds the configured send_max_bytes limit. + + Notes: + - Compression is applied only if the flags do not already indicate compression, + compression is enabled, and the data size exceeds the minimum threshold. + - The flags are updated to include the compressed flag if compression is performed. """ if flags in EnvelopeFlags.compressed or self.compression is None or len(data) < self.compress_min_bytes: diff --git a/src/connect/handler.py b/src/connect/handler.py index 4a53f95..61efdb2 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -1,11 +1,10 @@ """Module provides handler configurations and implementations for unary procedures and stream types.""" import asyncio -import inspect import logging from collections.abc import Awaitable, Callable from http import HTTPMethod, HTTPStatus -from typing import Any, TypeGuard +from typing import Any import anyio from starlette.responses import PlainTextResponse @@ -19,7 +18,6 @@ StreamRequest, StreamResponse, StreamType, - UnaryHandlerConn, UnaryRequest, UnaryResponse, receive_stream_request, @@ -46,6 +44,7 @@ from connect.protocol_grpc import ProtocolGPRC from connect.request import Request from connect.response import Response +from connect.utils import aiterate from connect.writer import ServerResponseWriter logging.basicConfig(level=logging.INFO) @@ -148,16 +147,11 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: return handlers 
-UnaryImplementationFunc = Callable[[UnaryHandlerConn, float | None], Awaitable[None]] -StreamImplementationFunc = Callable[[StreamingHandlerConn, float | None], Awaitable[None]] - - class Handler: """A class to handle incoming HTTP requests and generate appropriate HTTP responses. Attributes: procedure (str): The procedure name. - implementation (UnaryImplementationFunc | StreamImplementationFunc): The implementation function for handling requests. protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A dictionary mapping HTTP methods to protocol handlers. allow_methods (str): Allowed HTTP methods. accept_post (str): Accepted content types for POST requests. @@ -166,7 +160,6 @@ class Handler: """ procedure: str - implementation: UnaryImplementationFunc | StreamImplementationFunc protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]] allow_methods: str accept_post: str @@ -175,7 +168,6 @@ class Handler: def __init__( self, procedure: str, - implementation: UnaryImplementationFunc | StreamImplementationFunc, protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]], allow_methods: str, accept_post: str, @@ -184,18 +176,31 @@ def __init__( Args: procedure (str): The name of the procedure. - implementation (UnaryImplementationFunc | StreamImplementationFunc): The function implementing the procedure. protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A dictionary mapping HTTP methods to protocol handlers. allow_methods (str): A string specifying allowed HTTP methods. accept_post (str): A string specifying if POST method is accepted. """ self.procedure = procedure - self.implementation = implementation self.protocol_handlers = protocol_handlers self.allow_methods = allow_methods self.accept_post = accept_post + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle the implementation of the request processing. + + This method should be overridden by subclasses. 
+ + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + Raises: + NotImplementedError: If not implemented by a subclass + + """ + raise NotImplementedError("Implementation must be provided by subclass") + async def handle(self, request: Request) -> Response: """Handle an incoming HTTP request and return an HTTP response. @@ -255,6 +260,7 @@ async def handle(self, request: Request) -> Response: writer = ServerResponseWriter() + # Create tasks for handling the request and receiving responses main_task = asyncio.create_task(self._handle(request, response_headers, response_trailers, writer)) writer_task = asyncio.create_task(writer.receive()) @@ -291,40 +297,32 @@ async def handle(self, request: Request) -> Response: async def _handle( self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter ) -> None: - if self.is_stream(self.implementation): - await self.stream_handle(request, response_headers, response_trailers, writer) - else: + if getattr(self, "stream_type", StreamType.Unary) == StreamType.Unary: await self.unary_handle(request, response_headers, response_trailers, writer) + else: + await self.stream_handle(request, response_headers, response_trailers, writer) - def is_stream( - self, impl: UnaryImplementationFunc | StreamImplementationFunc - ) -> TypeGuard[StreamImplementationFunc]: - """Determine if the given implementation function is a stream implementation. - - Args: - impl (UnaryImplementationFunc | StreamImplementationFunc): The implementation function to check. + def is_stream(self) -> bool: + """Determine if this handler is a stream handler. Returns: - TypeGuard[StreamImplementationFunc]: True if the implementation function is a stream implementation, False otherwise. + bool: True if this is a stream handler, False otherwise. 
""" - signature = inspect.signature(impl) - parameters = signature.parameters - return len(parameters) == 2 and next(iter(parameters.values())).annotation == StreamingHandlerConn + # Since we've consolidated to a single connection type, use a sentinel value + is_stream_handler = getattr(self, "_is_stream_handler", False) + return is_stream_handler - def is_unary(self, impl: UnaryImplementationFunc | StreamImplementationFunc) -> TypeGuard[UnaryImplementationFunc]: - """Determine if the given implementation function is a unary implementation. - - Args: - impl (UnaryImplementationFunc | StreamImplementationFunc): The implementation function to check. + def is_unary(self) -> bool: + """Determine if this handler is a unary handler. Returns: - TypeGuard[UnaryImplementationFunc]: True if the implementation function is a unary implementation, False otherwise. + bool: True if this is a unary handler, False otherwise. """ - signature = inspect.signature(impl) - parameters = signature.parameters - return len(parameters) == 2 and next(iter(parameters.values())).annotation == UnaryHandlerConn + # Since we've consolidated to a single connection type, use a sentinel value + is_stream_handler = getattr(self, "_is_stream_handler", False) + return not is_stream_handler async def stream_handle( self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter @@ -341,31 +339,25 @@ async def stream_handle( None Raises: - ValueError: If the function type for the stream handler is invalid. + ValueError: If the implementation method is invalid. ConnectError: If an internal error occurs during the handling of the stream. 
""" - conn = await self.protocol_handler.stream_conn(request, response_headers, response_trailers, writer) + conn = await self.protocol_handler.conn(request, response_headers, response_trailers, writer, is_streaming=True) if conn is None: return - implementation = self.implementation - if not self.is_stream(implementation): - raise ValueError(f"Invalid function type for stream handler: {implementation}") - try: timeout = conn.parse_timeout() if timeout: timeout_ms = int(timeout * 1000) - with anyio.fail_after(delay=timeout): - await implementation(conn, timeout_ms) + await self.implementation(conn, timeout_ms) else: - await implementation(conn, None) + await self.implementation(conn, None) except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - await conn.send_error(error) async def unary_handle( @@ -380,7 +372,7 @@ async def unary_handle( writer (ServerResponseWriter): The writer to send the response. Raises: - ValueError: If the function type is invalid for unary handler. + ValueError: If the implementation method is invalid. ConnectError: If there is an error parsing the timeout or an internal error occurs. 
Returns: @@ -391,19 +383,14 @@ async def unary_handle( if conn is None: return - implementation = self.implementation - if not self.is_unary(implementation): - raise ValueError(f"Invalid function type for unary handler: {implementation}") - try: timeout = conn.parse_timeout() if timeout: timeout_ms = int(timeout * 1000) - with anyio.fail_after(delay=timeout): - await implementation(conn, timeout_ms) + await self.implementation(conn, timeout_ms) else: - await implementation(conn, None) + await self.implementation(conn, None) except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) @@ -413,7 +400,6 @@ async def unary_handle( if isinstance(e, NotImplementedError): error = ConnectError("not implemented", Code.UNIMPLEMENTED) - await conn.send_error(error) @@ -438,6 +424,7 @@ class UnaryHandler[T_Request, T_Response](Handler): protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]] allow_methods: str accept_post: str + stream_type: StreamType = StreamType.Unary def __init__( self, @@ -460,31 +447,44 @@ async def _untyped(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: UnaryHandlerConn, timeout: float | None) -> None: - request = await receive_unary_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - if not isinstance(response.message, output): - raise ConnectError( - f"expected response of type: {output.__name__}", - Code.INTERNAL, - ) - - conn.response_headers.update(exclude_protocol_headers(response.headers)) - conn.response_trailers.update(exclude_protocol_headers(response.trailers)) - await conn.send(response.message) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), 
allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a unary request and send a unary response. + + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + Raises: + ConnectError: If the response message type is incorrect + + """ + request = await receive_unary_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + if not isinstance(response.message, self.output): + raise ConnectError( + f"expected response of type: {self.output.__name__}", + Code.INTERNAL, + ) + + conn.response_headers.update(exclude_protocol_headers(response.headers)) + conn.response_trailers.update(exclude_protocol_headers(response.trailers)) + await conn.send(aiterate([response.message])) + class ServerStreamHandler[T_Request, T_Response](Handler): """A handler for server-side streaming RPCs. @@ -503,12 +503,14 @@ class ServerStreamHandler[T_Request, T_Response](Handler): """ + stream_type: StreamType = StreamType.ServerStream + def __init__( self, procedure: str, stream: StreamFunc[T_Request, T_Response], input: type[T_Request], - output: type[T_Response], # noqa: ARG002 + output: type[T_Response], options: ConnectOptions | None = None, ) -> None: """Initialize a new handler instance. 
@@ -527,31 +529,40 @@ def __init__( async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: response = await stream(request) - return response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: - request = await receive_stream_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - conn.response_headers.update(response.headers) - conn.response_trailers.update(response.trailers) - - await conn.send(response.messages) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a server stream request and response. + + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + """ + request = await receive_stream_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + conn.response_headers.update(response.headers) + conn.response_trailers.update(response.trailers) + + await conn.send(response.messages) + class ClientStreamHandler[T_Request, T_Response](Handler): """A handler for client-side streaming RPCs. @@ -575,12 +586,14 @@ class ClientStreamHandler[T_Request, T_Response](Handler): """ + stream_type: StreamType = StreamType.ClientStream + def __init__( self, procedure: str, stream: StreamFunc[T_Request, T_Response], input: type[T_Request], - output: type[T_Response], # noqa: ARG002 + output: type[T_Response], options: ConnectOptions | None = None, ) -> None: """Initialize a new instance of the handler. 
@@ -602,31 +615,40 @@ def __init__( async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: response = await stream(request) - return response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: - request = await receive_stream_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - conn.response_headers.update(response.headers) - conn.response_trailers.update(response.trailers) - - await conn.send(response.messages) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a client stream request and response. + + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + """ + request = await receive_stream_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + conn.response_headers.update(response.headers) + conn.response_trailers.update(response.trailers) + + await conn.send(response.messages) + class BidiStreamHandler[T_Request, T_Response](Handler): """A handler for bidirectional streaming RPCs in a Connect-based framework. @@ -664,12 +686,14 @@ class BidiStreamHandler[T_Request, T_Response](Handler): """ + stream_type: StreamType = StreamType.BiDiStream + def __init__( self, procedure: str, stream: StreamFunc[T_Request, T_Response], input: type[T_Request], - output: type[T_Response], # noqa: ARG002 + output: type[T_Response], options: ConnectOptions | None = None, ) -> None: """Initialize a bidirectional streaming handler. 
@@ -696,27 +720,36 @@ def __init__( async def _untyped(request: StreamRequest[T_Request]) -> StreamResponse[T_Response]: response = await stream(request) - return response untyped = apply_interceptors(_untyped, options.interceptors) - async def implementation(conn: StreamingHandlerConn, timeout: float | None) -> None: - request = await receive_stream_request(conn, input) - if timeout: - request.timeout = timeout - - response = await untyped(request) - - conn.response_headers.update(response.headers) - conn.response_trailers.update(response.trailers) - - await conn.send(response.messages) + self.input = input + self.output = output + self.untyped = untyped super().__init__( procedure=procedure, - implementation=implementation, protocol_handlers=mapped_method_handlers(protocol_handlers), allow_methods=sorted_allow_method_value(protocol_handlers), accept_post=sorted_accept_post_value(protocol_handlers), ) + + async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: + """Handle a bidirectional stream request and response. 
+ + Args: + conn: The connection handler + timeout: Optional timeout in milliseconds + + """ + request = await receive_stream_request(conn, self.input) + if timeout: + request.timeout = timeout + + response = await self.untyped(request) + + conn.response_headers.update(response.headers) + conn.response_trailers.update(response.trailers) + + await conn.send(response.messages) diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 2d2ae83..93ed608 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -16,7 +16,6 @@ StreamingHandlerConn, StreamType, UnaryClientConn, - UnaryHandlerConn, ) from connect.error import ConnectError from connect.headers import Headers @@ -108,7 +107,7 @@ def stream_conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: raise NotImplementedError() -HanderConn = UnaryHandlerConn | StreamingHandlerConn +HanderConn = StreamingHandlerConn class ProtocolHandler(abc.ABC): @@ -157,39 +156,24 @@ def can_handle_payload(self, request: Request, content_type: str) -> bool: @abc.abstractmethod async def conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> UnaryHandlerConn | None: - """Handle a unary connection request. - - Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be sent in the response. - response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer used to send the response. - - Returns: - UnaryHandlerConn | None: The connection handler or None if not implemented. - - Raises: - NotImplementedError: If the method is not implemented. 
- - """ - raise NotImplementedError() - - @abc.abstractmethod - async def stream_conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter + self, + request: Request, + response_headers: Headers, + response_trailers: Headers, + writer: ServerResponseWriter, + is_streaming: bool = False, ) -> StreamingHandlerConn | None: - """Handle a streaming connection. + """Handle a connection request. Args: request (Request): The incoming request object. response_headers (Headers): The headers to be sent in the response. response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer object to send the response. + writer (ServerResponseWriter): The writer used to send the response. + is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. Returns: - StreamingHandlerConn | None: The streaming handler connection object or None if not implemented. + StreamingHandlerConn | None: The connection handler or None if not implemented. Raises: NotImplementedError: If the method is not implemented. 
diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index c7c1c04..7de43b5 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -26,13 +26,13 @@ from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name from connect.connect import ( Address, + AsyncContentStream, Peer, Spec, StreamingClientConn, StreamingHandlerConn, StreamType, UnaryClientConn, - UnaryHandlerConn, ) from connect.envelope import Envelope, EnvelopeFlags from connect.error import DEFAULT_ANY_RESOLVER_PREFIX, ConnectError, ErrorDetail @@ -175,116 +175,44 @@ def can_handle_payload(self, request: Request, content_type: str) -> bool: return content_type in self.accept - async def stream_conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter + async def conn( + self, + request: Request, + response_headers: Headers, + response_trailers: Headers, + writer: ServerResponseWriter, + is_streaming: bool = False, ) -> StreamingHandlerConn | None: - """Establish a streaming connection for the given request and response. + """Handle a connection request. Args: request (Request): The incoming request object. - response_headers (Headers): Headers to be sent in the response. - response_trailers (Headers): Trailers to be sent in the response. - writer (ServerResponseWriter): The writer to send the response. + response_headers (Headers): The headers to be sent in the response. + response_trailers (Headers): The trailers to be sent in the response. + writer (ServerResponseWriter): The writer used to send the response. + is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. Returns: - StreamingHandlerConn | None: The streaming connection handler if no error occurs, otherwise None. + StreamingHandlerConn | None: The connection handler or None if not implemented. 
Raises: ConnectError: If there is an error in negotiating compression, protocol version, or message encoding. - """ - content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) - accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) - - request_compression, response_compression, error = negotiate_compression( - self.params.compressions, content_encoding, accept_encoding - ) - - if error is None: - required = self.params.require_connect_protocol_header and self.params.spec.stream_type == StreamType.Unary - error = connect_check_protocol_version(request, required) - - content_type = request.headers.get(HEADER_CONTENT_TYPE, "") - codec_name = connect_codec_from_content_type(self.params.spec.stream_type, content_type) - - codec = self.params.codecs.get(codec_name) - if error is None and codec is None: - error = ConnectError( - f"invalid message encoding: {codec_name}", - Code.INVALID_ARGUMENT, - ) - - response_headers[HEADER_CONTENT_TYPE] = content_type - - if response_compression and response_compression.name != COMPRESSION_IDENTITY: - response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name - - response_headers[CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION] = ( - f"{', '.join(c.name for c in self.params.compressions)}" - ) - - peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, - protocol=PROTOCOL_CONNECT, - query=request.query_params, - ) - - stream_conn = ConnectStreamingHandlerConn( - writer=writer, - request=request, - peer=peer, - spec=self.params.spec, - marshaler=ConnectStreamingMarshaler( - codec=codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=response_compression, - ), - unmarshaler=ConnectStreamingUnmarshaler( - stream=request.stream(), - codec=codec, - compression=request_compression, - read_max_bytes=self.params.read_max_bytes, - ), - 
request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) - - if error: - await stream_conn.send_error(error) - return None - - return stream_conn - - async def conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> UnaryHandlerConn | None: - """Handle an incoming connection request and returns a UnaryHandlerConn object if successful. - - Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be included in the response. - response_trailers (Headers): The trailers to be included in the response. - writer (ServerResponseWriter): The writer object to send the response. - - Returns: - UnaryHandlerConn | None: A UnaryHandlerConn object if the connection is successfully established, - otherwise None if an error occurs. - - Raises: - ConnectError: If there are issues with the request parameters, encoding, or protocol version. 
- """ query_params = request.query_params - if HTTPMethod(request.method) == HTTPMethod.GET: - content_encoding = query_params.get(CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, None) + # Get content encoding and accept encoding appropriately based on stream type + if is_streaming: + content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) else: - content_encoding = request.headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) - - accept_encoding = request.headers.get(CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, None) + if HTTPMethod(request.method) == HTTPMethod.GET: + content_encoding = query_params.get(CONNECT_UNARY_COMPRESSION_QUERY_PARAMETER, None) + else: + content_encoding = request.headers.get(CONNECT_UNARY_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, None) + # Negotiate compression request_compression, response_compression, error = negotiate_compression( self.params.compressions, content_encoding, accept_encoding ) @@ -293,7 +221,8 @@ async def conn( required = self.params.require_connect_protocol_header and self.params.spec.stream_type == StreamType.Unary error = connect_check_protocol_version(request, required) - if HTTPMethod(request.method) == HTTPMethod.GET: + # Process GET parameters for unary requests + if not is_streaming and HTTPMethod(request.method) == HTTPMethod.GET: encoding = query_params.get(CONNECT_UNARY_ENCODING_QUERY_PARAMETER, "") message = query_params.get(CONNECT_UNARY_MESSAGE_QUERY_PARAMETER, "") if error is None and encoding == "": @@ -325,6 +254,7 @@ async def conn( content_type = request.headers.get(HEADER_CONTENT_TYPE, "") codec_name = connect_codec_from_content_type(self.params.spec.stream_type, content_type) + # Get codec codec = self.params.codecs.get(codec_name) if error is None and codec is None: error = ConnectError( @@ -332,40 +262,76 @@ async def conn( 
Code.INVALID_ARGUMENT, ) + # Set response headers response_headers[HEADER_CONTENT_TYPE] = content_type - response_headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = ( - f"{', '.join(c.name for c in self.params.compressions)}" - ) + if is_streaming: + if response_compression and response_compression.name != COMPRESSION_IDENTITY: + response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name + + response_headers[CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION] = ( + f"{', '.join(c.name for c in self.params.compressions)}" + ) + else: + response_headers[CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION] = ( + f"{', '.join(c.name for c in self.params.compressions)}" + ) + # Create peer peer = Peer( address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, protocol=PROTOCOL_CONNECT, query=request.query_params, ) - conn = ConnectUnaryHandlerConn( - writer=writer, - request=request, - peer=peer, - spec=self.params.spec, - marshaler=ConnectUnaryMarshaler( - codec=codec, - compress_min_bytes=self.params.compress_min_bytes, - send_max_bytes=self.params.send_max_bytes, - compression=response_compression, - headers=response_headers, - ), - unmarshaler=ConnectUnaryUnmarshaler( - stream=request_stream, - codec=codec, - compression=request_compression, - read_max_bytes=self.params.read_max_bytes, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) + # Create appropriate handler connection based on stream type + conn: StreamingHandlerConn | None = None + if is_streaming: + conn = ConnectStreamingHandlerConn( + writer=writer, + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectStreamingMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=response_compression, + ), + unmarshaler=ConnectStreamingUnmarshaler( + 
stream=request.stream(), + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) + else: + conn = ConnectUnaryHandlerConn( + writer=writer, + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectUnaryMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=response_compression, + headers=response_headers, + ), + unmarshaler=ConnectUnaryUnmarshaler( + stream=request_stream, + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) + if error: await conn.send_error(error) return None @@ -620,7 +586,7 @@ def marshal(self, message: Any) -> bytes: return data -class ConnectUnaryHandlerConn(UnaryHandlerConn): +class ConnectUnaryHandlerConn(StreamingHandlerConn): """ConnectUnaryHandlerConn is a handler connection class for unary RPCs in the Connect protocol. Attributes: @@ -709,18 +675,33 @@ def peer(self) -> Peer: """ return self._peer - async def receive(self, message: Any) -> Any: - """Receives a message, unmarshals it, and returns the resulting object. + async def receive_message(self, message: Any) -> AsyncIterator[Any]: + """Receives and unmarshals a message into an object. Args: message (Any): The message to be unmarshaled. Returns: - Any: The unmarshaled object. + AsyncIterator[Any]: An async iterator yielding the unmarshaled object. """ obj = await self.unmarshaler.unmarshal(message) - return obj + yield obj + + def receive(self, message: Any) -> AsyncContentStream[Any]: + """Receives a message, unmarshals it, and returns the resulting object. 
+ + Args: + message (Any): The message to be unmarshaled. + + Returns: + AsyncIterator[Any]: An async iterator yielding the unmarshaled object. + + """ + return AsyncContentStream( + self.receive_message(message), + stream_type=self.spec.stream_type, + ) @property def request_headers(self) -> Headers: @@ -732,18 +713,32 @@ def request_headers(self) -> Headers: """ return self._request_headers - async def send(self, message: Any) -> None: - """Send a message by marshaling it into bytes. + async def send(self, messages: AsyncIterable[Any]) -> None: + """Send message(s) by marshaling them into bytes. Args: - message (Any): The message to be sent. + messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, + this should be an iterable with a single item. Returns: - bytes: The marshaled message in bytes. + None """ self.merge_response_trailers() + message = None + count = 0 + + async for msg in messages: + count += 1 + if count > 1: + raise ConnectError("unary handler should only send one message", Code.INTERNAL) + + message = msg + + if message is None: + raise ConnectError("unary handler must send one message", Code.INTERNAL) + data = self.marshaler.marshal(message) response = Response(content=data, headers=self.response_headers, status_code=HTTPStatus.OK) await self.writer.write(response) @@ -1543,19 +1538,44 @@ def peer(self) -> Peer: """ return self._peer - async def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message, unmarshals it, and returns the resulting object. + async def _receive_message(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously receives a message and yields unmarshaled objects. + + This method unmarshals the received message and yields each + unmarshaled object one by one as an asynchronous iterator. Args: - message (Any): The message to be unmarshaled. + message (Any): The message to unmarshal. Returns: - Any: The unmarshaled object. 
+ AsyncIterator[Any]: An asynchronous iterator yielding unmarshaled objects. + + Yields: + Any: Each unmarshaled object from the message. """ async for obj, _ in self.unmarshaler.unmarshal(message): yield obj + def receive(self, message: Any) -> AsyncContentStream[Any]: + """Receives a message and returns an asynchronous content stream. + + This method processes the incoming message through the receive_message method + and wraps the result in an AsyncContentStream with the appropriate stream type. + + Args: + message (Any): The message to be processed. + + Returns: + AsyncContentStream[Any]: An asynchronous stream of content based on the + processed message, configured with the specification's stream type. + + """ + return AsyncContentStream( + iterable=self._receive_message(message), + stream_type=self.spec.stream_type, + ) + @property def request_headers(self) -> Headers: """Retrieve the headers from the request. @@ -1566,6 +1586,32 @@ def request_headers(self) -> Headers: """ return self._request_headers + async def create_message_iterator(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + """Create an async iterator that marshals messages with error handling. + + Args: + messages (AsyncIterable[Any]): Messages to marshal + + Returns: + AsyncIterator[bytes]: Marshaled bytes with end stream message + + Yields: + bytes: Each marshaled message followed by an end stream message + + """ + error: ConnectError | None = None + try: + async for message in self.marshaler.marshal(messages): + yield message + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + finally: + json_obj = end_stream_to_json(error, self.response_trailers) + json_str = json.dumps(json_obj) + + body = self.marshaler.marshal_end_stream(json_str.encode()) + yield body + async def send(self, messages: AsyncIterable[Any]) -> None: """Send a stream of messages asynchronously. 
@@ -1583,24 +1629,9 @@ async def send(self, messages: AsyncIterable[Any]) -> None: ConnectError: If an error occurs during the marshaling process. """ - - async def iterator() -> AsyncIterator[bytes]: - error: ConnectError | None = None - try: - async for message in self.marshaler.marshal(messages): - yield message - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - json_obj = end_stream_to_json(error, self.response_trailers) - json_str = json.dumps(json_obj) - - body = self.marshaler.marshal_end_stream(json_str.encode()) - yield body - await self.writer.write( StreamingResponse( - content=iterator(), + content=self.create_message_iterator(messages), headers=self.response_headers, status_code=200, ) @@ -2053,18 +2084,30 @@ def peer(self) -> Peer: """ return self._peer - async def receive(self, message: Any) -> Any: - """Asynchronously receives a message, unmarshals it, and returns the resulting object. + async def _receive_message(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously receives and unmarshals a message, yielding the resulting object. Args: message (Any): The message to be unmarshaled. - Returns: - None: This method does not return a value. The unmarshaled object is returned implicitly. + Yields: + Any: The unmarshaled object. """ obj = await self.unmarshaler.unmarshal(message) - return obj + yield obj + + def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message and returns an asynchronous iterator over the processed message. + + Args: + message (Any): The message to be received and processed. + + Returns: + AsyncIterator[Any]: An asynchronous iterator yielding processed message(s). 
+
+        """
+        return self._receive_message(message)
 
     @property
     def request_headers(self) -> Headers:
@@ -2281,6 +2324,18 @@ def event_hooks(self, event_hooks: dict[str, list[EventHook]]) -> None:
             "response": list(event_hooks.get("response", [])),
         }
 
+    async def aclose(self) -> None:
+        """Asynchronously closes the connection or releases any resources held by the object.
+
+        This method should be called when the object is no longer needed to ensure proper cleanup.
+        Currently, this implementation does not perform any actions, but it can be overridden in subclasses.
+
+        Returns:
+            None
+
+        """
+        return
+
 
 def connect_validate_unary_response_content_type(
     request_codec_name: str,
diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py
index 1d5fd3e..ef106bd 100644
--- a/src/connect/protocol_grpc.py
+++ b/src/connect/protocol_grpc.py
@@ -1,3 +1,5 @@
+"""Provides classes and functions for handling gRPC protocol."""
+
 import base64
 import re
 import urllib.parse
@@ -10,7 +12,7 @@
 from connect.code import Code
 from connect.codec import Codec, CodecNameType
 from connect.compression import COMPRESSION_IDENTITY, Compression
-from connect.connect import Address, Peer, Spec, StreamingHandlerConn, UnaryHandlerConn
+from connect.connect import Address, AsyncContentStream, Peer, Spec, StreamingHandlerConn, StreamType
 from connect.envelope import EnvelopeReader, EnvelopeWriter
 from connect.error import ConnectError
 from connect.headers import Headers
@@ -58,12 +60,38 @@
 
 
 class ProtocolGPRC(Protocol):
-    web: bool
+    """ProtocolGPRC is a protocol implementation for handling gRPC and gRPC-Web requests.
+
+    Attributes:
+        web (bool): Indicates whether to use gRPC-Web (True) or standard gRPC (False).
+
+    """
 
     def __init__(self, web: bool) -> None:
+        """Initialize the instance.
+
+        Args:
+            web (bool): Indicates whether the instance is for web usage.
+
+        """
         self.web = web
 
     def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler:
+        """Create and return a GRPCHandler instance configured with appropriate content types based on the provided parameters.
+
+        Args:
+            params (ProtocolHandlerParams): The parameters containing codec information and other handler configuration.
+
+        Returns:
+            ProtocolHandler: An instance of GRPCHandler initialized with the correct content types for gRPC or gRPC-Web.
+
+        Behavior:
+            - Determines the default and prefix content types based on whether gRPC-Web is enabled.
+            - Constructs a list of supported content types from the available codecs.
+            - Adds the bare content type if the PROTO codec is present.
+            - Returns a GRPCHandler with the computed content types.
+
+        """
         bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX
         if self.web:
             bare, prefix = GRPC_WEB_CONTENT_TYPE_DEFAULT, GRPC_WEB_CONTENT_TYPE_PREFIX
@@ -78,89 +106,107 @@ def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler:
         return GRPCHandler(params, self.web, content_types)
 
     def client(self, params: ProtocolClientParams) -> ProtocolClient:
-        """Implement client functionality.
+        """Create and return a GRPCClient instance.
+
+        Args:
+            params (ProtocolClientParams): The parameters required to initialize the client.
+
+        Returns:
+            ProtocolClient: An instance of GRPCClient.
 
-        This method currently does nothing and is intended to be implemented
-        in the future with the necessary client-side logic.
         """
-        raise NotImplementedError()
+        raise NotImplementedError("GRPC client is not implemented yet.")
 
 
 class GRPCHandler(ProtocolHandler):
+    """GRPCHandler is a protocol handler for gRPC and gRPC-Web requests.
+
+    This class implements the ProtocolHandler interface to handle gRPC protocol requests,
+    including negotiation of compression, codec selection, and connection management for
+    both standard gRPC and gRPC-Web. 
It supports content type negotiation, payload handling, + and manages the lifecycle of a gRPC connection, including streaming and non-streaming + requests. + + Attributes: + params (ProtocolHandlerParams): Configuration parameters for the handler, including codecs and compressions. + web (bool): Indicates if the handler is for gRPC-Web. + accept (list[str]): List of accepted content types. + + """ + params: ProtocolHandlerParams web: bool accept: list[str] def __init__(self, params: ProtocolHandlerParams, web: bool, accept: list[str]) -> None: + """Initialize the ProtocolHandler with the given parameters. + + Args: + params (ProtocolHandlerParams): The parameters required for the protocol handler. + web (bool): Indicates whether the handler is for web usage. + accept (list[str]): A list of accepted content types. + + Returns: + None + + """ self.params = params self.web = web self.accept = accept @property def methods(self) -> list[HTTPMethod]: + """Returns a list of allowed HTTP methods for gRPC protocol. + + Returns: + list[HTTPMethod]: A list containing the HTTP methods permitted for gRPC communication. + + """ return GRPC_ALLOWED_METHODS def content_types(self) -> list[str]: + """Return a list of accepted content types. + + Returns: + list[str]: A list of MIME types that are accepted. + + """ return self.accept def can_handle_payload(self, _: Request, content_type: str) -> bool: - return content_type in self.accept + """Determine if the given content type is supported by this handler. - async def conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> UnaryHandlerConn | None: - content_encoding = request.headers.get(GRPC_HEADER_COMPRESSION) - accept_encoding = request.headers.get(GRPC_HEADER_ACCEPT_COMPRESSION) + Args: + _ (Request): The request object (unused). + content_type (str): The MIME type of the payload to check. 
- request_compression, response_compression, error = negotiate_compression( - self.params.compressions, content_encoding, accept_encoding - ) + Returns: + bool: True if the content type is accepted, False otherwise. - response_headers[HEADER_CONTENT_TYPE] = request.headers.get(HEADER_CONTENT_TYPE, "") - response_headers[GRPC_HEADER_ACCEPT_COMPRESSION] = f"{', '.join(c.name for c in self.params.compressions)}" - if response_compression and response_compression.name != COMPRESSION_IDENTITY: - response_headers[GRPC_HEADER_COMPRESSION] = response_compression.name + """ + return content_type in self.accept - codec_name = grpc_codec_from_content_type(self.web, request.headers.get(HEADER_CONTENT_TYPE, "")) - codec = self.params.codecs.get(codec_name) - protocol_name = PROTOCOL_GRPC if not self.web else PROTOCOL_GRPC + "-web" + async def conn( + self, + request: Request, + response_headers: Headers, + response_trailers: Headers, + writer: ServerResponseWriter, + is_streaming: bool = False, + ) -> StreamingHandlerConn | None: + """Handle a connection request. - peer = Peer( - address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, - protocol=protocol_name, - query=request.query_params, - ) + Args: + request (Request): The incoming request object. + response_headers (Headers): The headers to be sent in the response. + response_trailers (Headers): The trailers to be sent in the response. + writer (ServerResponseWriter): The writer used to send the response. + is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. 
- conn = GRPCHandlerConn( - writer=writer, - spec=self.params.spec, - peer=peer, - marshaler=GRPCMarshaler( - self.web, - codec, - response_compression, - self.params.compress_min_bytes, - self.params.send_max_bytes, - ), - unmarshaler=GRPCUnmarshaler( - codec, - self.params.read_max_bytes, - request.stream(), - request_compression, - ), - request_headers=Headers(request.headers, encoding="latin-1"), - response_headers=response_headers, - response_trailers=response_trailers, - ) - if error: - await conn.send_error(error) - return None + Returns: + StreamingHandlerConn | None: The connection handler or None if not implemented. - return conn - - async def stream_conn( - self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter - ) -> StreamingHandlerConn | None: + """ content_encoding = request.headers.get(GRPC_HEADER_COMPRESSION) accept_encoding = request.headers.get(GRPC_HEADER_ACCEPT_COMPRESSION) @@ -183,7 +229,7 @@ async def stream_conn( query=request.query_params, ) - conn = GRPCStreamingHandlerConn( + conn = GRPCHandlerConn( writer=writer, spec=self.params.spec, peer=peer, @@ -203,7 +249,9 @@ async def stream_conn( request_headers=Headers(request.headers, encoding="latin-1"), response_headers=response_headers, response_trailers=response_trailers, + is_streaming=is_streaming, ) + if error: await conn.send_error(error) return None @@ -212,6 +260,22 @@ async def stream_conn( class GRPCMarshaler(EnvelopeWriter): + """GRPCMarshaler is responsible for marshaling messages into the gRPC wire format. + + Args: + web (bool): Indicates whether to use the gRPC-web protocol. + codec (Codec | None): The codec used for encoding/decoding messages. + compression (Compression | None): The compression algorithm to use, if any. + compress_min_bytes (int): Minimum message size in bytes before compression is applied. + send_max_bytes (int): Maximum allowed size of a message to send. 
+ + Methods: + marshal(messages: AsyncIterable[bytes]) -> AsyncIterator[bytes]: + Asynchronously marshals a stream of message bytes into the gRPC wire format. + Yields marshaled message bytes ready for transmission. + + """ + web: bool def __init__( @@ -222,15 +286,52 @@ def __init__( compress_min_bytes: int, send_max_bytes: int, ) -> None: + """Initialize the protocol with the specified configuration. + + Args: + web (bool): Indicates whether the protocol is used in a web context. + codec (Codec | None): The codec to use for encoding/decoding messages, or None for default. + compression (Compression | None): The compression algorithm to use, or None for no compression. + compress_min_bytes (int): The minimum number of bytes before compression is applied. + send_max_bytes (int): The maximum number of bytes allowed to send in a single message. + + Returns: + None + + """ super().__init__(codec, compression, compress_min_bytes, send_max_bytes) self.web = web async def marshal(self, messages: AsyncIterable[bytes]) -> AsyncIterator[bytes]: + """Asynchronously marshals a stream of byte messages. + + Args: + messages (AsyncIterable[bytes]): An asynchronous iterable of byte messages to be marshaled. + + Yields: + AsyncIterator[bytes]: An asynchronous iterator yielding marshaled byte messages. + + """ async for message in self._marshal(messages): yield message class GRPCUnmarshaler(EnvelopeReader): + """GRPCUnmarshaler is a specialized EnvelopeReader for handling gRPC message unmarshaling. + + Args: + codec (Codec | None): The codec used for decoding messages. + read_max_bytes (int): The maximum number of bytes to read from the stream. + stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read messages from. + compression (Compression | None, optional): Compression algorithm to use for decompressing messages. 
+ + Methods: + async unmarshal(message: Any) -> AsyncIterator[Any]: + Asynchronously unmarshals the given message, yielding each decoded object. + Iterates over the results of the internal _unmarshal method, yielding only the object part of each tuple. + + """ + def __init__( self, codec: Codec | None, @@ -238,14 +339,47 @@ def __init__( stream: AsyncIterable[bytes] | None = None, compression: Compression | None = None, ) -> None: + """Initialize the protocol gRPC handler. + + Args: + codec (Codec | None): The codec to use for encoding/decoding messages. Can be None. + read_max_bytes (int): The maximum number of bytes to read from the stream. + stream (AsyncIterable[bytes] | None, optional): An asynchronous iterable stream of bytes. Defaults to None. + compression (Compression | None, optional): The compression method to use. Defaults to None. + + """ super().__init__(codec, read_max_bytes, stream, compression) async def unmarshal(self, message: Any) -> AsyncIterator[Any]: + """Asynchronously unmarshals a given message and yields each resulting object. + + Args: + message (Any): The message to be unmarshaled. + + Yields: + Any: Each object obtained from unmarshaling the message. + + """ async for obj, _ in self._unmarshal(message): yield obj -class GRPCHandlerConn(UnaryHandlerConn): +class GRPCHandlerConn(StreamingHandlerConn): + """GRPCHandlerConn is a handler class for managing gRPC protocol connections within a streaming server context. + + This class encapsulates the logic for handling gRPC requests and responses, including marshaling and unmarshaling messages, + managing request and response headers/trailers, handling timeouts, and enforcing protocol-specific constraints for unary and streaming operations. + + Attributes: + _spec (Spec): The specification object describing the protocol or service. + _peer (Peer): The peer information for the current connection. + _request_headers (Headers): The headers received with the request. 
+ _response_headers (Headers): The headers to include in the response. + _response_trailers (Headers): The trailers to include in the response. + _is_streaming (bool): Indicates if the connection is streaming. + + """ + _spec: Spec _peer: Peer writer: ServerResponseWriter @@ -254,6 +388,7 @@ class GRPCHandlerConn(UnaryHandlerConn): _request_headers: Headers _response_headers: Headers _response_trailers: Headers + _is_streaming: bool def __init__( self, @@ -265,7 +400,22 @@ def __init__( request_headers: Headers, response_headers: Headers, response_trailers: Headers | None = None, + is_streaming: bool = False, ) -> None: + """Initialize a new instance of the class. + + Args: + writer (ServerResponseWriter): The writer used to send responses to the client. + spec (Spec): The specification object describing the protocol or service. + peer (Peer): The peer information for the current connection. + marshaler (GRPCMarshaler): The marshaler used to serialize response messages. + unmarshaler (GRPCUnmarshaler): The unmarshaler used to deserialize request messages. + request_headers (Headers): The headers received with the request. + response_headers (Headers): The headers to include in the response. + response_trailers (Headers | None, optional): The trailers to include in the response. Defaults to None. + is_streaming (bool, optional): Indicates if the connection is streaming. Defaults to False. + + """ self.writer = writer self._spec = spec self._peer = peer @@ -274,8 +424,22 @@ def __init__( self._request_headers = request_headers self._response_headers = response_headers self._response_trailers = response_trailers if response_trailers is not None else Headers() + self._is_streaming = is_streaming def parse_timeout(self) -> float | None: + """Parse the gRPC timeout value from the request headers and returns it as seconds. + + Returns: + float | None: The timeout value in seconds if present and valid, otherwise None. 
+ + Raises: + ConnectError: If the timeout value is present but invalid or too long. + + Notes: + - The timeout is extracted from the gRPC header and must match the expected format. + - If the timeout unit is hours and exceeds the maximum allowed, None is returned. + + """ timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) if not timeout: return None @@ -298,44 +462,75 @@ def parse_timeout(self) -> float | None: @property def spec(self) -> Spec: + """Returns the specification object associated with this instance. + + Returns: + Spec: The specification object. + + """ return self._spec @property def peer(self) -> Peer: + """Returns the associated Peer object. + + Returns: + Peer: The peer instance associated with this object. + + """ return self._peer - async def receive(self, message: Any) -> Any: - first = None - async for obj in self.unmarshaler.unmarshal(message): - # TODO(tsubakiky): validation - if first is None: - first = obj - else: - raise ConnectError("protocol error: expected only one message, but got multiple", Code.UNIMPLEMENTED) + def receive(self, message: Any) -> AsyncContentStream[Any]: + """Receives a message and processes it. - if first is None: - raise ConnectError("protocol error: expected one message, but got none", Code.UNIMPLEMENTED) + Args: + message (Any): The message to be received and processed. - return first + Returns: + AsyncIterator[Any]: An async iterator yielding message(s). For non-streaming operations, + this will yield exactly one item. + + """ + return AsyncContentStream(self.unmarshaler.unmarshal(message), self.spec.stream_type) @property def request_headers(self) -> Headers: + """Returns the headers associated with the current request. + + Returns: + Headers: The headers of the request. 
+ + """ return self._request_headers - async def send(self, message: Any) -> None: - async def iterator() -> AsyncIterator[bytes]: - error: ConnectError | None = None - try: - async for msg in self.marshaler.marshal(aiterate([message])): - yield msg - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - grpc_error_to_trailer(self.response_trailers, error) + async def send(self, messages: AsyncIterable[Any]) -> None: + """Send message(s) by marshaling them into bytes. + + Args: + messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, + this should be an iterable with a single item. + + Returns: + None + """ + # Validation for unary streams - ensure exactly one message + if not self._is_streaming and self.spec.stream_type == StreamType.Unary: + message_list = [] + async for msg in messages: + message_list.append(msg) + + if len(message_list) != 1: + raise ConnectError( + f"unary handler expected to send exactly one message, got {len(message_list)}", Code.INTERNAL + ) + + messages = aiterate(message_list) + + # Use the common marshaling method await self.writer.write( StreamingResponseWithTrailers( - content=iterator(), + content=self.marshal_with_error_handling(messages), headers=self.response_headers, trailers=self.response_trailers, status_code=200, @@ -344,119 +539,59 @@ async def iterator() -> AsyncIterator[bytes]: @property def response_headers(self) -> Headers: + """Returns the response headers associated with the current request. + + Returns: + Headers: The headers returned in the response. + + """ return self._response_headers @property def response_trailers(self) -> Headers: - return self._response_trailers + """Returns the response trailers as headers. 
- async def send_error(self, error: ConnectError) -> None: - grpc_error_to_trailer(self.response_trailers, error) + Response trailers are additional metadata sent by the server after the response body, + typically used in gRPC and HTTP/2 protocols. - await self.writer.write( - StreamingResponseWithTrailers( - content=aiterate([b""]), headers=self.response_headers, trailers=self.response_trailers, status_code=200 - ) - ) + Returns: + Headers: The response trailers associated with the current response. + """ + return self._response_trailers -class GRPCStreamingHandlerConn(StreamingHandlerConn): - _spec: Spec - _peer: Peer - writer: ServerResponseWriter - marshaler: GRPCMarshaler - unmarshaler: GRPCUnmarshaler - _request_headers: Headers - _response_headers: Headers - _response_trailers: Headers - - def __init__( - self, - writer: ServerResponseWriter, - spec: Spec, - peer: Peer, - marshaler: GRPCMarshaler, - unmarshaler: GRPCUnmarshaler, - request_headers: Headers, - response_headers: Headers, - response_trailers: Headers | None = None, - ) -> None: - self.writer = writer - self._spec = spec - self._peer = peer - self.marshaler = marshaler - self.unmarshaler = unmarshaler - self._request_headers = request_headers - self._response_headers = response_headers - self._response_trailers = response_trailers if response_trailers is not None else Headers() - - def parse_timeout(self) -> float | None: - timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) - if not timeout: - return None - - m = _RE.match(timeout) - if m is None: - raise ConnectError(f"protocol error: invalid grpc timeout value: {timeout}") - - num_str, unit = m.groups() - num = int(num_str) - - if num > 99_999_999: - raise ConnectError(f"protocol error: timeout {timeout!r} is too long") - - if unit == "H" and num > _MAX_HOURS: - return None - - seconds = num * _UNIT_TO_SECONDS[unit] - return seconds - - @property - def spec(self) -> Spec: - return self._spec + async def marshal_with_error_handling(self, 
messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + """Marshal messages to bytes with error handling. - @property - def peer(self) -> Peer: - return self._peer + Args: + messages (AsyncIterable[Any]): The messages to marshal - async def receive(self, message: Any) -> AsyncIterator[Any]: - async for obj in self.unmarshaler.unmarshal(message): - # TODO(tsubakiky): validation - yield obj + Returns: + AsyncIterator[bytes]: An async iterator of marshaled bytes - @property - def request_headers(self) -> Headers: - return self._request_headers + """ + error: ConnectError | None = None + try: + async for msg in self.marshaler.marshal(messages): + yield msg + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + finally: + grpc_error_to_trailer(self.response_trailers, error) - async def send(self, messages: AsyncIterable[Any]) -> None: - async def iterator() -> AsyncIterator[bytes]: - error: ConnectError | None = None - try: - async for msg in self.marshaler.marshal(messages): - yield msg - except Exception as e: - error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) - finally: - grpc_error_to_trailer(self.response_trailers, error) + async def send_error(self, error: ConnectError) -> None: + """Send an error response over gRPC by converting the provided ConnectError into gRPC trailers. - await self.writer.write( - StreamingResponseWithTrailers( - content=iterator(), - headers=self.response_headers, - trailers=self.response_trailers, - status_code=200, - ) - ) + Args: + error (ConnectError): The error to be sent as a gRPC trailer. 
- @property - def response_headers(self) -> Headers: - return self._response_headers + Returns: + None - @property - def response_trailers(self) -> Headers: - return self._response_trailers + This method updates the response trailers with the error information and writes a streaming response + with the appropriate headers and trailers to the client. - async def send_error(self, error: ConnectError) -> None: + """ grpc_error_to_trailer(self.response_trailers, error) await self.writer.write( @@ -467,6 +602,18 @@ async def send_error(self, error: ConnectError) -> None: def grpc_codec_from_content_type(web: bool, content_type: str) -> str: + """Determine the gRPC codec name from the given content type string. + + Args: + web (bool): Indicates whether the request is a gRPC-web request. + content_type (str): The content type string to parse. + + Returns: + str: The codec name extracted from the content type. If the content type matches the default gRPC or gRPC-web content type, + returns the default codec name. Otherwise, extracts and returns the codec name from the content type prefix, or returns + the original content type if no known prefix is found. + + """ if (not web and content_type == GRPC_CONTENT_TYPE_DEFAULT) or ( web and content_type == GRPC_WEB_CONTENT_TYPE_DEFAULT ): @@ -481,6 +628,21 @@ def grpc_codec_from_content_type(web: bool, content_type: str) -> str: def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: + """Convert a ConnectError to gRPC trailer headers. + + Args: + trailer (Headers): The trailer headers dictionary to update with gRPC error information. + error (ConnectError | None): The error to convert. If None, indicates success. + + Side Effects: + Modifies the `trailer` dictionary in-place to include gRPC status, message, and optional details. + + Notes: + - If `error` is None, sets the gRPC status header to "0" (OK). 
+ - If `ConnectError.wire_error` is False, updates the trailer with error metadata excluding protocol headers. + - Serializes error details using protobuf if present, encoding them in base64 for the trailer. + + """ if error is None: trailer[GRPC_HEADER_STATUS] = "0" return diff --git a/src/connect/response_trailer.py b/src/connect/response_trailer.py index 66787a1..a566ef7 100644 --- a/src/connect/response_trailer.py +++ b/src/connect/response_trailer.py @@ -1,4 +1,4 @@ -from __future__ import annotations +"""Streaming HTTP response with support for trailers.""" import typing from functools import partial @@ -17,6 +17,21 @@ class StreamingResponseWithTrailers(Response): + """A streaming HTTP response class that supports HTTP trailers. + + This class extends the standard response to allow sending HTTP trailers + at the end of a streamed response body, if supported by the ASGI server. + + Attributes: + body_iterator (AsyncContentStream): An asynchronous iterator over the response body content. + status_code (int): HTTP status code for the response. + media_type (str | None): The media type of the response. + background (BackgroundTask | None): Optional background task to run after response is sent. + headers (Mapping[str, str]): HTTP headers for the response. + _trailers (Mapping[str, str] | None): HTTP trailers to send after the response body. + + """ + body_iterator: AsyncContentStream def __init__( @@ -29,6 +44,21 @@ def __init__( media_type: str | None = None, background: BackgroundTask | None = None, ) -> None: + """Initialize a response object with optional HTTP trailers. + + Args: + content (ContentStream): The response body content, which can be an async iterable or a regular iterable. + status_code (int, optional): HTTP status code for the response. Defaults to 200. + headers (typing.Mapping[str, str] | None, optional): HTTP headers to include in the response. Defaults to None. 
+ trailers (typing.Mapping[str, str] | None, optional): HTTP trailers to include in the response. Defaults to None. + media_type (str | None, optional): The media type of the response. If None, uses the default media type. Defaults to None. + background (BackgroundTask | None, optional): A background task to run after the response is sent. Defaults to None. + + Notes: + - If `content` is not an async iterable, it will be wrapped to run in a thread pool. + - If trailers are provided, their names will be added to the "Trailer" header. + + """ if isinstance(content, typing.AsyncIterable): self.body_iterator = content else: @@ -69,6 +99,24 @@ async def _stream_response(self, send: Send, trailers_supported: bool) -> None: }) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Handle the ASGI call interface for streaming HTTP responses with optional support for HTTP trailers. + + This method determines the ASGI spec version and whether HTTP response trailers are supported. + For ASGI spec version >= 2.4, it streams the response and handles client disconnects. + For earlier versions, it concurrently streams the response and listens for client disconnects, + cancelling the response stream if a disconnect is detected. + + After sending the response, if a background task is provided, it is awaited. + + Args: + scope (Scope): The ASGI connection scope. + receive (Receive): Awaitable callable to receive ASGI messages. + send (Send): Awaitable callable to send ASGI messages. + + Raises: + ClientDisconnect: If the client disconnects during response streaming. + + """ spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) trailers_supported = "http.response.trailers" in scope.get("extensions", {})