Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
77 changes: 77 additions & 0 deletions src/connect/asgi_helpers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Utility functions for ASGI helpers, including request/response decorators and route path extraction."""

import typing
from collections.abc import (
Awaitable,
Callable,
)

from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send

from connect.utils import is_async_callable, run_in_threadpool


def request_response(func: Callable[[Request], Awaitable[Response] | Response]) -> ASGIApp:
"""Convert a request handler function into an ASGI application.

This decorator takes a function that handles a request and returns a response,
and wraps it into an ASGI application callable. The handler function can be either
synchronous or asynchronous.

Args:
func (Callable[[Request], Awaitable[Response] | Response]): The request handler function.
It can be a synchronous function returning a Response or an asynchronous function
returning an Awaitable of Response.

Returns:
ASGIApp: An ASGI application callable that can be used to handle ASGI requests.

"""

async def async_func(request: Request) -> Response:
if is_async_callable(func):
return await func(request)
else:
return typing.cast(Response, await run_in_threadpool(func, request))

f: Callable[[Request], Awaitable[Response]] = async_func

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, send)
response = await f(request)
await response(scope, receive, send)

return app


def get_route_path(scope: Scope) -> str:
"""Extract the route path from the given scope.

Args:
scope (Scope): The scope dictionary containing the request information.

Returns:
str: The extracted route path. If a root path is specified in the scope,
the function returns the path relative to the root path. If the path
does not start with the root path or if the path is equal to the root
path, the function returns the original path or an empty string,
respectively.

"""
path: str = scope["path"]
root_path = scope.get("root_path", "")
if not root_path:
return path

if not path.startswith(root_path):
return path

if path == root_path:
return ""

if path[len(root_path)] == "/":
return path[len(root_path) :]

return path
56 changes: 0 additions & 56 deletions src/connect/byte_stream.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from connect.interceptor import apply_interceptors
from connect.options import ClientOptions
from connect.protocol import Protocol, ProtocolClient, ProtocolClientParams
from connect.protocol_connect import ProtocolConnect
from connect.protocol_grpc import ProtocolGRPC
from connect.protocol_connect.connect_protocol import ProtocolConnect
from connect.protocol_grpc.grpc_protocol import ProtocolGRPC
from connect.session import AsyncClientSession
from connect.utils import aiterate

Expand Down
3 changes: 2 additions & 1 deletion src/connect/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from pydantic import BaseModel

from connect.code import Code
from connect.content_stream import AsyncDataStream
from connect.error import ConnectError
from connect.headers import Headers
from connect.idempotency_level import IdempotencyLevel
from connect.utils import AsyncDataStream, aiterate, get_acallable_attribute, get_callable_attribute
from connect.utils import aiterate, get_acallable_attribute, get_callable_attribute


class StreamType(Enum):
Expand Down
158 changes: 158 additions & 0 deletions src/connect/content_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Asynchronous byte stream utilities for HTTP core response handling."""

from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
)

from connect.utils import (
get_acallable_attribute,
map_httpcore_exceptions,
)


class AsyncByteStream(AsyncIterable[bytes]):
"""An abstract base class for asynchronous byte streams.

This class defines the interface for an asynchronous byte stream, which
includes methods for iterating over the stream and closing it.

"""

async def __aiter__(self) -> AsyncIterator[bytes]:
"""Asynchronous iterator method.

This method should be implemented to provide asynchronous iteration
over the object. It must return an asynchronous iterator that yields
bytes.

Raises:
NotImplementedError: If the method is not implemented.

"""
raise NotImplementedError("The '__aiter__' method must be implemented.") # pragma: no cover
yield b""

async def aclose(self) -> None:
"""Asynchronously close the byte stream."""
pass


class BoundAsyncStream(AsyncByteStream):
"""An asynchronous byte stream wrapper that binds to an existing async iterable of bytes.

This class provides an asynchronous iterator interface for reading byte chunks from the given stream,
and ensures proper resource cleanup by closing the underlying stream when needed.

Args:
stream (AsyncIterable[bytes]): The asynchronous iterable byte stream to wrap.

Attributes:
stream (AsyncIterable[bytes]): The wrapped asynchronous byte stream.
_closed (bool): Indicates whether the stream has been closed.

"""

_stream: AsyncIterable[bytes]
_closed: bool

def __init__(self, stream: AsyncIterable[bytes]) -> None:
"""Initialize the object with an asynchronous iterable stream of bytes.

Args:
stream (AsyncIterable[bytes]): An asynchronous iterable that yields bytes.

"""
self._stream = stream
self._closed = False

async def __aiter__(self) -> AsyncIterator[bytes]:
"""Asynchronous iterator method to read byte chunks from the stream."""
try:
with map_httpcore_exceptions():
async for chunk in self._stream:
yield chunk
except BaseException as exc:
await self.aclose()
raise exc

async def aclose(self) -> None:
"""Asynchronously close the stream."""
if self._closed:
return

self._closed = True
with map_httpcore_exceptions():
aclose = get_acallable_attribute(self._stream, "aclose")
if aclose:
await aclose()


class AsyncDataStream[T]:
"""An asynchronous data stream wrapper that provides iteration and cleanup functionality.

Type Parameters:
T: The type of items yielded by the stream.

Attributes:
_stream (AsyncIterable[T]): The underlying asynchronous iterable data stream.
aclose_func (Callable[..., Awaitable[None]] | None): Optional asynchronous cleanup function to be called on close.

"""

_stream: AsyncIterable[T]
_aclose_func: Callable[..., Awaitable[None]] | None
_closed: bool

def __init__(self, stream: AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None) -> None:
"""Initialize the object with an asynchronous iterable stream and an optional asynchronous close function.

Args:
stream (AsyncIterable[T]): The asynchronous iterable stream to be wrapped.
aclose_func (Callable[..., Awaitable[None]], optional): An optional asynchronous function to be called when closing the stream. Defaults to None.

"""
self._stream = stream
self._aclose_func = aclose_func
self._closed = False

async def __aiter__(self) -> AsyncIterator[T]:
"""Asynchronously iterates over the underlying stream, yielding each part.

Yields:
T: The next part from the stream.

Raises:
Propagates any exception raised during iteration after ensuring the stream is closed.

"""
try:
async for part in self._stream:
yield part
except BaseException as exc:
await self.aclose()
raise exc

async def aclose(self) -> None:
"""Asynchronously closes the underlying stream.

If a custom asynchronous close function (`aclose_func`) is provided, it is awaited.
Otherwise, if the underlying stream has an `aclose` method, it is retrieved and awaited.

Raises:
Any exception raised by the custom close function or the stream's `aclose` method.

"""
if self._closed:
return

self.closed = True
if self._aclose_func:
await self._aclose_func()

elif hasattr(self._stream, "aclose"):
aclose = get_acallable_attribute(self._stream, "aclose")
if aclose:
await aclose()
9 changes: 4 additions & 5 deletions src/connect/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any

import anyio
from starlette.responses import PlainTextResponse
from starlette.responses import PlainTextResponse, Response

from connect.code import Code
from connect.codec import Codec, CodecMap, CodecNameType, ProtoBinaryCodec, ProtoJSONCodec
Expand Down Expand Up @@ -37,14 +37,13 @@
sorted_accept_post_value,
sorted_allow_method_value,
)
from connect.protocol_connect import (
from connect.protocol_connect.connect_protocol import (
ProtocolConnect,
)
from connect.protocol_grpc import ProtocolGRPC
from connect.protocol_grpc.grpc_protocol import ProtocolGRPC
from connect.request import Request
from connect.response import Response
from connect.response_writer import ServerResponseWriter
from connect.utils import aiterate
from connect.writer import ServerResponseWriter

type UnaryFunc[T_Request, T_Response] = Callable[[UnaryRequest[T_Request]], Awaitable[UnaryResponse[T_Response]]]
type StreamFunc[T_Request, T_Response] = Callable[[StreamRequest[T_Request]], Awaitable[StreamResponse[T_Response]]]
Expand Down
2 changes: 1 addition & 1 deletion src/connect/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send

from connect.asgi_helpers.utils import get_route_path, request_response
from connect.handler import Handler
from connect.utils import get_route_path, request_response

HandleFunc = Callable[[Request], Awaitable[Response]]

Expand Down
2 changes: 1 addition & 1 deletion src/connect/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from connect.headers import Headers
from connect.idempotency_level import IdempotencyLevel
from connect.request import Request
from connect.response_writer import ServerResponseWriter
from connect.session import AsyncClientSession
from connect.writer import ServerResponseWriter

PROTOCOL_CONNECT = "connect"
PROTOCOL_GRPC = "grpc"
Expand Down
Loading
Loading