diff --git a/conformance/client_config.yaml b/conformance/client_config.yaml index 0ad7321..796997a 100644 --- a/conformance/client_config.yaml +++ b/conformance/client_config.yaml @@ -5,6 +5,7 @@ features: protocols: - PROTOCOL_CONNECT - PROTOCOL_GRPC + - PROTOCOL_GRPC_WEB codecs: - CODEC_PROTO compressions: @@ -15,7 +16,6 @@ features: - STREAM_TYPE_CLIENT_STREAM - STREAM_TYPE_SERVER_STREAM - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM - # - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM supports_h2c: true supports_tls: true diff --git a/conformance/server_config.yaml b/conformance/server_config.yaml index 7c16099..b4004bb 100644 --- a/conformance/server_config.yaml +++ b/conformance/server_config.yaml @@ -5,6 +5,7 @@ features: protocols: - PROTOCOL_CONNECT - PROTOCOL_GRPC + - PROTOCOL_GRPC_WEB codecs: - CODEC_PROTO - CODEC_JSON diff --git a/src/connect/envelope.py b/src/connect/envelope.py index 01bf131..a8630f2 100644 --- a/src/connect/envelope.py +++ b/src/connect/envelope.py @@ -23,6 +23,7 @@ class EnvelopeFlags(Flag): compressed = 0b00000001 end_stream = 0b00000010 + trailer = 0b10000000 class Envelope: @@ -253,7 +254,7 @@ class EnvelopeReader: stream: AsyncIterable[bytes] | None buffer: bytes bytes_read: int - last_data: bytes | None + last: Envelope | None def __init__( self, @@ -277,7 +278,7 @@ def __init__( self.stream = stream self.buffer = b"" self.bytes_read = 0 - self.last_data = None + self.last = None async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: """Asynchronously unmarshals messages from the stream. @@ -325,7 +326,7 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: env.data = self.compression.decompress(env.data, self.read_max_bytes) if env.flags != EnvelopeFlags(0) and env.flags != EnvelopeFlags.compressed: - self.last_data = env.data + self.last = env end = True obj = None else: diff --git a/src/connect/handler.py b/src/connect/handler.py index cdb51ce..69ddb51 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -123,7 +123,7 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: list[ProtocolHandler]: A list of initialized protocol handlers. """ - protocols = [ProtocolConnect(), ProtocolGRPC(web=False)] + protocols = [ProtocolConnect(), ProtocolGRPC(web=False), ProtocolGRPC(web=True)] codecs = CodecMap(config.codecs) diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index dc963fa..f7a7c60 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -1189,8 +1189,8 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: """ async for obj, end in super().unmarshal(message): - if self.last_data: - error, trailers = end_stream_from_bytes(self.last_data) + if self.last: + error, trailers = end_stream_from_bytes(self.last.data) self._end_stream_error = error self._trailers = trailers diff --git a/src/connect/protocol_grpc.py b/src/connect/protocol_grpc.py index a7e02e7..750c659 100644 --- a/src/connect/protocol_grpc.py +++ b/src/connect/protocol_grpc.py @@ -8,6 +8,7 @@ import sys import urllib.parse from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping +from copy import copy from http import HTTPMethod from typing import Any from urllib.parse import unquote @@ -29,7 +30,7 @@ StreamingHandlerConn, StreamType, ) -from connect.envelope import EnvelopeReader, EnvelopeWriter +from connect.envelope import EnvelopeFlags, EnvelopeReader, EnvelopeWriter from connect.error import ConnectError, ErrorDetail from connect.headers import Headers, include_request_headers from connect.protocol import ( @@ -49,7 +50,7 @@ from connect.request import Request from connect.session import AsyncClientSession from connect.streaming_response import StreamingResponse -from connect.utils import map_httpcore_exceptions +from connect.utils import aiterate, map_httpcore_exceptions from connect.version import __version__ from connect.writer import ServerResponseWriter @@ -262,6 +263,7 @@ async def conn( ) conn = GRPCHandlerConn( + web=self.web, writer=writer, spec=self.params.spec, peer=peer, @@ -272,6 +274,7 @@ async def conn( self.params.send_max_bytes, ), unmarshaler=GRPCUnmarshaler( + self.web, codec, self.params.read_max_bytes, request.stream(), @@ -395,6 +398,7 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: compression=get_compresion_from_name(self.params.compression_name, self.params.compressions), ), unmarshaler=GRPCUnmarshaler( + web=self.web, codec=self.params.codec, read_max_bytes=self.params.read_max_bytes, ), @@ -439,6 +443,24 @@ def __init__( """ super().__init__(codec, compression, compress_min_bytes, send_max_bytes) + async def marshal_web_trailers(self, trailers: Headers) -> bytes: + """Serialize HTTP trailer headers into a gRPC-Web trailer envelope. + + Args: + trailers (Headers): A dictionary-like object containing HTTP trailer headers. + + Returns: + bytes: The serialized gRPC-Web trailer envelope containing the trailer headers. + + """ + lines = [] + for key, value in trailers.items(): + lines.append(f"{key}: {value}\r\n") + + env = self.write_envelope("".join(lines).encode(), EnvelopeFlags.trailer) + + return env.encode() + class GRPCUnmarshaler(EnvelopeReader): """GRPCUnmarshaler is a specialized EnvelopeReader for handling gRPC message unmarshaling. @@ -456,8 +478,12 @@ class GRPCUnmarshaler(EnvelopeReader): """ + web: bool + _web_trailers: Headers | None + def __init__( self, + web: bool, codec: Codec | None, read_max_bytes: int, stream: AsyncIterable[bytes] | None = None, @@ -466,6 +492,7 @@ def __init__( """Initialize the protocol gRPC handler. Args: + web (bool): Indicates if the connection is for a web environment. codec (Codec | None): The codec to use for encoding/decoding messages. Can be None. read_max_bytes (int): The maximum number of bytes to read from the stream. stream (AsyncIterable[bytes] | None, optional): An asynchronous iterable stream of bytes. Defaults to None. @@ -473,8 +500,10 @@ def __init__( """ super().__init__(codec, read_max_bytes, stream, compression) + self.web = web + self._web_trailers = None - async def unmarshal(self, message: Any) -> AsyncIterator[Any]: + async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: """Asynchronously unmarshals a given message and yields each resulting object. Args: @@ -484,8 +513,46 @@ async def unmarshal(self, message: Any) -> AsyncIterator[Any]: Any: Each object obtained from unmarshaling the message. """ - async for obj, _ in super().unmarshal(message): - yield obj + async for obj, end in super().unmarshal(message): + if end: + env = self.last + if not env: + raise ConnectError("protocol error: empty envelope") + data = copy(env.data) + env.data = b"" + + if not (self.web and env.is_set(EnvelopeFlags.trailer)): + raise ConnectError( + f"protocol error: invalid envelope flags: {env.flags}", + ) + + trailers = Headers() + lines = data.decode("utf-8").splitlines() + for line in lines: + if line == "": + continue + + name, value = line.split(":", 1) + name = name.strip().lower() + value = value.strip() + if name in trailers: + trailers[name] += "," + value + else: + trailers[name] = value + + self._web_trailers = trailers + + yield obj, end + + @property + def web_trailers(self) -> Headers | None: + """Return the trailers received in the last envelope. + + Returns: + Headers | None: The trailers received in the last envelope, or None if no trailers were received. + + """ + return self._web_trailers EventHook = Callable[..., Any] @@ -588,10 +655,25 @@ def peer(self) -> Peer: async def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: """Receives a message and processes it.""" - async for obj in self.unmarshaler.unmarshal(message): + trailer_received = False + + async for obj, end in self.unmarshaler.unmarshal(message): if abort_event and abort_event.is_set(): raise ConnectError("receive operation aborted", Code.CANCELED) + if end: + if trailer_received: + raise ConnectError("received extra end stream trailer", Code.INVALID_ARGUMENT) + + trailer_received = True + if self.unmarshaler.web_trailers is None: + raise ConnectError("trailer not received", Code.INVALID_ARGUMENT) + + continue + + if trailer_received: + raise ConnectError("protocol error: received extra message after trailer", Code.INVALID_ARGUMENT) + yield obj if callable(self.receive_trailers): @@ -613,11 +695,17 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None) -> Asyn raise server_error def _receive_trailers(self, response: httpcore.Response) -> None: - if "trailing_headers" not in response.extensions: - return + if self.web: + trailers = self.unmarshaler.web_trailers + if trailers is not None: + self._response_trailers.update(trailers) + + else: + if "trailing_headers" not in response.extensions: + return - trailers = response.extensions["trailing_headers"] - self._response_trailers.update(Headers(trailers)) + trailers = response.extensions["trailing_headers"] + self._response_trailers.update(Headers(trailers)) @property def request_headers(self) -> Headers: @@ -771,6 +859,7 @@ class GRPCHandlerConn(StreamingHandlerConn): """ + web: bool _spec: Spec _peer: Peer writer: ServerResponseWriter @@ -782,6 +871,7 @@ class GRPCHandlerConn(StreamingHandlerConn): def __init__( self, + web: bool, writer: ServerResponseWriter, spec: Spec, peer: Peer, @@ -794,6 +884,7 @@ def __init__( """Initialize a new instance of the class. Args: + web (bool): Indicates if the connection is for a web environment. writer (ServerResponseWriter): The writer used to send responses to the client. spec (Spec): The specification object describing the protocol or service. peer (Peer): The peer information for the current connection. @@ -805,6 +896,7 @@ def __init__( is_streaming (bool, optional): Indicates if the connection is streaming. Defaults to False. """ + self.web = web self.writer = writer self._spec = spec self._peer = peer @@ -867,7 +959,7 @@ def peer(self) -> Peer: """ return self._peer - def receive(self, message: Any) -> AsyncIterator[Any]: + async def receive(self, message: Any) -> AsyncIterator[Any]: """Receives a message and processes it. Args: @@ -878,7 +970,8 @@ def receive(self, message: Any) -> AsyncIterator[Any]: this will yield exactly one item. """ - return self.unmarshaler.unmarshal(message) + async for obj, _ in self.unmarshaler.unmarshal(message): + yield obj @property def request_headers(self) -> Headers: @@ -901,14 +994,23 @@ async def send(self, messages: AsyncIterable[Any]) -> None: None """ - await self.writer.write( - StreamingResponse( - content=self.marshal_with_error_handling(messages), - headers=self.response_headers, - trailers=self.response_trailers, - status_code=200, + if self.web: + await self.writer.write( + StreamingResponse( + content=self._send_messages(messages), + headers=self.response_headers, + status_code=200, + ) + ) + else: + await self.writer.write( + StreamingResponse( + content=self._send_messages(messages), + headers=self.response_headers, + trailers=self.response_trailers, + status_code=200, + ) ) - ) @property def response_headers(self) -> Headers: @@ -933,14 +1035,21 @@ def response_trailers(self) -> Headers: """ return self._response_trailers - async def marshal_with_error_handling(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: - """Marshal messages to bytes with error handling. + async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: + """Asynchronously sends marshaled messages and yields them as byte streams. Args: - messages (AsyncIterable[Any]): The messages to marshal + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent. - Returns: - AsyncIterator[bytes]: An async iterator of marshaled bytes + Yields: + bytes: Marshaled message bytes, and optionally marshaled web trailers if in web mode. + + Raises: + ConnectError: If an error occurs during marshaling or sending messages, a ConnectError is set and handled. + + Notes: + - Errors encountered during message marshaling are converted to ConnectError and added to response trailers. + - If running in web mode (`self.web` is True), marshaled web trailers are yielded at the end. """ error: ConnectError | None = None @@ -952,6 +1061,10 @@ async def marshal_with_error_handling(self, messages: AsyncIterable[Any]) -> Asy finally: grpc_error_to_trailer(self.response_trailers, error) + if self.web: + body = await self.marshaler.marshal_web_trailers(self.response_trailers) + yield body + async def send_error(self, error: ConnectError) -> None: """Send an error response over gRPC by converting the provided ConnectError into gRPC trailers. @@ -966,15 +1079,25 @@ async def send_error(self, error: ConnectError) -> None: """ grpc_error_to_trailer(self.response_trailers, error) - - await self.writer.write( - StreamingResponse( - content=[], - headers=self.response_headers, - trailers=self.response_trailers, - status_code=200, + if self.web: + body = await self.marshaler.marshal_web_trailers(self.response_trailers) + + await self.writer.write( + StreamingResponse( + content=aiterate([body]), + headers=self.response_headers, + status_code=200, + ) + ) + else: + await self.writer.write( + StreamingResponse( + content=[], + headers=self.response_headers, + trailers=self.response_trailers, + status_code=200, + ) ) - ) def grpc_codec_from_content_type(web: bool, content_type: str) -> str: