Skip to content

Commit 69f727e

Browse files
authored
Improve connection loss handling (#19)
1 parent 2b7f2e5 commit 69f727e

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

stompman/connection.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import socket
3-
from collections.abc import AsyncGenerator, Iterator
3+
from collections.abc import AsyncGenerator, Generator, Iterator
4+
from contextlib import contextmanager
45
from dataclasses import dataclass, field
56
from typing import Protocol, Self, TypedDict, TypeVar, cast
67

@@ -9,7 +10,7 @@
910
from stompman.protocol import NEWLINE, Parser, dump_frame
1011

1112

12-
class _MultiHostHostLike(TypedDict):
13+
class MultiHostHostLike(TypedDict):
1314
username: str | None
1415
password: str | None
1516
host: str | None
@@ -24,7 +25,7 @@ class ConnectionParameters:
2425
passcode: str = field(repr=False)
2526

2627
@classmethod
27-
def from_pydantic_multihost_hosts(cls, hosts: list[_MultiHostHostLike]) -> list[Self]:
28+
def from_pydantic_multihost_hosts(cls, hosts: list[MultiHostHostLike]) -> list[Self]:
2829
"""Create connection parameters from a list of `MultiHostUrl` objects.
2930
3031
.. code-block:: python
@@ -101,16 +102,25 @@ async def connect(self) -> bool:
101102
return False
102103
return True
103104

105+
@contextmanager
106+
def _reraise_connection_lost(self, *causes: type[Exception]) -> Generator[None, None, None]:
107+
try:
108+
yield
109+
except causes as exception:
110+
raise ConnectionLostError(self.read_timeout) from exception
111+
104112
async def close(self) -> None:
105113
self.writer.close()
106-
await self.writer.wait_closed()
114+
with self._reraise_connection_lost(ConnectionError):
115+
await self.writer.wait_closed()
107116

108117
def write_heartbeat(self) -> None:
109118
return self.writer.write(NEWLINE)
110119

111120
async def write_frame(self, frame: AnyClientFrame) -> None:
112121
self.writer.write(dump_frame(frame))
113-
await self.writer.drain()
122+
with self._reraise_connection_lost(ConnectionError):
123+
await self.writer.drain()
114124

115125
async def _read_non_empty_bytes(self) -> bytes:
116126
while (
@@ -123,10 +133,8 @@ async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]:
123133
parser = Parser()
124134

125135
while True:
126-
try:
136+
with self._reraise_connection_lost(ConnectionError, TimeoutError):
127137
raw_frames = await asyncio.wait_for(self._read_non_empty_bytes(), timeout=self.read_timeout)
128-
except (TimeoutError, ConnectionError) as exception:
129-
raise ConnectionLostError(self.read_timeout) from exception
130138

131139
for frame in cast(Iterator[AnyServerFrame], parser.load_frames(raw_frames)):
132140
yield frame

stompman/frames.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class NackFrame:
123123
body: bytes = b""
124124

125125

126-
BeginHeaders = TypedDict("BeginHeaders", {"transaction": NotRequired[str], "content-length": NotRequired[str]})
126+
BeginHeaders = TypedDict("BeginHeaders", {"transaction": str, "content-length": NotRequired[str]})
127127

128128

129129
@dataclass
@@ -132,7 +132,7 @@ class BeginFrame:
132132
body: bytes = b""
133133

134134

135-
CommitHeaders = TypedDict("CommitHeaders", {"transaction": NotRequired[str], "content-length": NotRequired[str]})
135+
CommitHeaders = TypedDict("CommitHeaders", {"transaction": str, "content-length": NotRequired[str]})
136136

137137

138138
@dataclass
@@ -141,7 +141,7 @@ class CommitFrame:
141141
body: bytes = b""
142142

143143

144-
AbortHeaders = TypedDict("AbortHeaders", {"transaction": NotRequired[str], "content-length": NotRequired[str]})
144+
AbortHeaders = TypedDict("AbortHeaders", {"transaction": str, "content-length": NotRequired[str]})
145145

146146

147147
@dataclass

tests/test_connection.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import socket
23
from collections.abc import Awaitable
34
from functools import partial
45
from typing import Any
@@ -14,7 +15,7 @@
1415
ConnectionParameters,
1516
HeartbeatFrame,
1617
)
17-
from stompman.frames import CommitFrame
18+
from stompman.frames import BeginFrame, CommitFrame
1819

1920

2021
@pytest.fixture()
@@ -110,13 +111,40 @@ async def take_frames(count: int) -> list[AnyServerFrame]:
110111
assert MockWriter.write.mock_calls == [mock.call(b"\n"), mock.call(b"COMMIT\ntransaction:transaction\n\n\x00")]
111112

112113

114+
async def test_connection_close_connection_error(connection: Connection, monkeypatch: pytest.MonkeyPatch) -> None:
115+
class MockWriter:
116+
close = mock.Mock()
117+
wait_closed = mock.AsyncMock(side_effect=ConnectionError)
118+
119+
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(mock.Mock(), MockWriter())))
120+
await connection.connect()
121+
122+
with pytest.raises(ConnectionLostError):
123+
await connection.close()
124+
125+
126+
async def test_connection_write_frame_connection_error(connection: Connection, monkeypatch: pytest.MonkeyPatch) -> None:
127+
class MockWriter:
128+
write = mock.Mock()
129+
drain = mock.AsyncMock(side_effect=ConnectionError)
130+
131+
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(mock.Mock(), MockWriter())))
132+
await connection.connect()
133+
134+
with pytest.raises(ConnectionLostError):
135+
await connection.write_frame(BeginFrame(headers={"transaction": ""}))
136+
137+
113138
async def test_connection_timeout(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None:
114139
mock_wait_for(monkeypatch)
115140
assert not await connection.connect()
116141

117142

118-
async def test_connection_error(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None:
119-
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(side_effect=ConnectionError))
143+
@pytest.mark.parametrize("exception", [BrokenPipeError, socket.gaierror])
144+
async def test_connection_connect_connection_error(
145+
monkeypatch: pytest.MonkeyPatch, connection: Connection, exception: type[Exception]
146+
) -> None:
147+
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(side_effect=exception))
120148
assert not await connection.connect()
121149

122150

0 commit comments

Comments
 (0)