diff --git a/stompman/connection.py b/stompman/connection.py index 5fdd47a..58be58f 100644 --- a/stompman/connection.py +++ b/stompman/connection.py @@ -124,7 +124,7 @@ async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]: while True: try: raw_frames = await asyncio.wait_for(self._read_non_empty_bytes(), timeout=self.read_timeout) - except TimeoutError as exception: + except (TimeoutError, ConnectionError) as exception: raise ConnectionLostError(self.read_timeout) from exception for frame in cast(Iterator[AnyServerFrame], parser.load_frames(raw_frames)): diff --git a/tests/test_connection.py b/tests/test_connection.py index fa95ac2..c80f766 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -120,7 +120,7 @@ async def test_connection_error(monkeypatch: pytest.MonkeyPatch, connection: Con assert not await connection.connect() -async def test_read_timeout(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None: +async def test_read_frames_timeout_error(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None: monkeypatch.setattr( "asyncio.open_connection", mock.AsyncMock(return_value=(mock.AsyncMock(read=partial(asyncio.sleep, 5)), mock.AsyncMock())), @@ -129,3 +129,15 @@ async def test_read_timeout(monkeypatch: pytest.MonkeyPatch, connection: Connect mock_wait_for(monkeypatch) with pytest.raises(ConnectionLostError): [frame async for frame in connection.read_frames()] + + +async def test_read_frames_connection_error(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None: + monkeypatch.setattr( + "asyncio.open_connection", + mock.AsyncMock( + return_value=(mock.AsyncMock(read=mock.AsyncMock(side_effect=BrokenPipeError)), mock.AsyncMock()) + ), + ) + await connection.connect() + with pytest.raises(ConnectionLostError): + [frame async for frame in connection.read_frames()]