diff --git a/src/connect/content_stream.py b/src/connect/content_stream.py index 1b93656..5644d58 100644 --- a/src/connect/content_stream.py +++ b/src/connect/content_stream.py @@ -55,7 +55,7 @@ class BoundAsyncStream(AsyncByteStream): """ - _stream: AsyncIterable[bytes] + _stream: AsyncIterable[bytes] | None _closed: bool def __init__(self, stream: AsyncIterable[bytes]) -> None: @@ -70,13 +70,19 @@ def __init__(self, stream: AsyncIterable[bytes]) -> None: async def __aiter__(self) -> AsyncIterator[bytes]: """Asynchronous iterator method to read byte chunks from the stream.""" + if self._stream is None: + return + try: with map_httpcore_exceptions(): async for chunk in self._stream: yield chunk - except BaseException as exc: - await self.aclose() - raise exc + except Exception as exc: + try: + await self.aclose() + except Exception as close_exc: + raise ExceptionGroup("Multiple errors occurred", [exc, close_exc]) from exc + raise async def aclose(self) -> None: """Asynchronously close the stream.""" @@ -84,10 +90,14 @@ async def aclose(self) -> None: return self._closed = True - with map_httpcore_exceptions(): - aclose = get_acallable_attribute(self._stream, "aclose") - if aclose: - await aclose() + try: + if self._stream is not None: + with map_httpcore_exceptions(): + aclose = get_acallable_attribute(self._stream, "aclose") + if aclose: + await aclose() + finally: + self._stream = None class AsyncDataStream[T]: @@ -102,7 +112,7 @@ class AsyncDataStream[T]: """ - _stream: AsyncIterable[T] + _stream: AsyncIterable[T] | None _aclose_func: Callable[..., Awaitable[None]] | None _closed: bool @@ -128,12 +138,18 @@ async def __aiter__(self) -> AsyncIterator[T]: Propagates any exception raised during iteration after ensuring the stream is closed. """ + if self._stream is None: + return + try: async for part in self._stream: yield part - except BaseException as exc: - await self.aclose() - raise exc + except Exception as exc: + try: + await self.aclose() + except Exception as close_exc: + raise ExceptionGroup("Multiple errors occurred", [exc, close_exc]) from exc + raise async def aclose(self) -> None: """Asynchronously closes the underlying stream. @@ -149,10 +165,13 @@ async def aclose(self) -> None: return self._closed = True - if self._aclose_func: - await self._aclose_func() - - elif hasattr(self._stream, "aclose"): - aclose = get_acallable_attribute(self._stream, "aclose") - if aclose: - await aclose() + try: + if self._aclose_func: + await self._aclose_func() + elif self._stream is not None: + aclose = get_acallable_attribute(self._stream, "aclose") + if aclose: + await aclose() + finally: + self._stream = None + self._aclose_func = None