Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions src/connect/content_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class BoundAsyncStream(AsyncByteStream):

"""

_stream: AsyncIterable[bytes]
_stream: AsyncIterable[bytes] | None
_closed: bool

def __init__(self, stream: AsyncIterable[bytes]) -> None:
Expand All @@ -70,24 +70,34 @@ def __init__(self, stream: AsyncIterable[bytes]) -> None:

async def __aiter__(self) -> AsyncIterator[bytes]:
"""Asynchronous iterator method to read byte chunks from the stream."""
if self._stream is None:
return

try:
with map_httpcore_exceptions():
async for chunk in self._stream:
yield chunk
except BaseException as exc:
await self.aclose()
raise exc
except Exception as exc:
try:
await self.aclose()
except Exception as close_exc:
raise ExceptionGroup("Multiple errors occurred", [exc, close_exc]) from exc
raise

async def aclose(self) -> None:
"""Asynchronously close the stream."""
if self._closed:
return

self._closed = True
with map_httpcore_exceptions():
aclose = get_acallable_attribute(self._stream, "aclose")
if aclose:
await aclose()
try:
if self._stream is not None:
with map_httpcore_exceptions():
aclose = get_acallable_attribute(self._stream, "aclose")
if aclose:
await aclose()
finally:
self._stream = None


class AsyncDataStream[T]:
Expand All @@ -102,7 +112,7 @@ class AsyncDataStream[T]:

"""

_stream: AsyncIterable[T]
_stream: AsyncIterable[T] | None
_aclose_func: Callable[..., Awaitable[None]] | None
_closed: bool

Expand All @@ -128,12 +138,18 @@ async def __aiter__(self) -> AsyncIterator[T]:
Propagates any exception raised during iteration after ensuring the stream is closed.

"""
if self._stream is None:
return

try:
async for part in self._stream:
yield part
except BaseException as exc:
await self.aclose()
raise exc
except Exception as exc:
try:
await self.aclose()
except Exception as close_exc:
raise ExceptionGroup("Multiple errors occurred", [exc, close_exc]) from exc
raise

async def aclose(self) -> None:
"""Asynchronously closes the underlying stream.
Expand All @@ -149,10 +165,13 @@ async def aclose(self) -> None:
return

self._closed = True
if self._aclose_func:
await self._aclose_func()

elif hasattr(self._stream, "aclose"):
aclose = get_acallable_attribute(self._stream, "aclose")
if aclose:
await aclose()
try:
if self._aclose_func:
await self._aclose_func()
elif self._stream is not None:
aclose = get_acallable_attribute(self._stream, "aclose")
if aclose:
await aclose()
finally:
self._stream = None
self._aclose_func = None
Loading