diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index de84a70..f16caa7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: - uses: extractions/setup-just@v2 - run: just install check-types - lint-format: + lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -45,9 +45,9 @@ jobs: path: | ~/.cache/pip ~/.cache/pypoetry - key: lint-format-${{ hashFiles('pyproject.toml') }} + key: lint-${{ hashFiles('pyproject.toml') }} - uses: extractions/setup-just@v2 - - run: just install lint-format + - run: just install lint test: runs-on: ubuntu-latest diff --git a/Justfile b/Justfile index 1d67556..fd350a4 100644 --- a/Justfile +++ b/Justfile @@ -1,4 +1,4 @@ -default: install lint-format check-types test +default: install lint check-types test install: poetry install --sync @@ -6,7 +6,7 @@ install: test *args: poetry run pytest {{args}} -lint-format: +lint: poetry run ruff check . poetry run ruff format . diff --git a/README.md b/README.md index 3c8732b..1f52520 100644 --- a/README.md +++ b/README.md @@ -111,9 +111,18 @@ stompman takes care of cleaning up resources automatically. When you leave the c ### Handling Connectivity Issues -- If multiple servers are provided, stompman will attempt to connect to each one simultaneously and use the first that succeeds. -- If all servers fail to connect, an `stompman.FailedAllConnectAttemptsError` will be raised. There're no need to handle it, if you should tune retry and timeout parameters to your needs. -- If a connection is lost, a `stompman.ReadTimeoutError` will be raised. You'll need to implement reconnect logic manually. Implementing reconnect logic in the library would be too complex, since there're no global state and clean-ups are automatic (e.g. it won't be possible to re-subscribe to destination because client doesn't keep track of subscriptions). +- If multiple servers were provided, stompman will attempt to connect to each one simultaneously and will use the first that succeeds. + +- If all servers fail to connect, an `stompman.FailedAllConnectAttemptsError` will be raised. In normal situation it doesn't need to be handled: tune retry and timeout parameters in `stompman.Client()` to your needs. + +- If a connection is lost, a `stompman.ConnectionLostError` will be raised. You should implement reconnect logic manually, for example, with stamina: + + ```python + for attempt in stamina.retry_context(on=stompman.ConnectionLostError): + with attempt: + async with stompman.Client(...) as client: + ... + ``` ### ...and caveats diff --git a/pyproject.toml b/pyproject.toml index f00134b..6ec235c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,9 @@ target-version = "py311" fix = true unsafe-fixes = true line-length = 120 + [tool.ruff.lint] +preview = true select = ["ALL"] ignore = [ "EM", @@ -58,6 +60,7 @@ ignore = [ "ISC001", "S101", "SLF001", + "CPY001", ] [tool.pytest.ini_options] diff --git a/stompman/__init__.py b/stompman/__init__.py index c8dfa2c..e64111e 100644 --- a/stompman/__init__.py +++ b/stompman/__init__.py @@ -1,19 +1,18 @@ from stompman.client import Client, Heartbeat from stompman.connection import AbstractConnection, Connection, ConnectionParameters from stompman.errors import ( - ConnectError, ConnectionConfirmationTimeoutError, + ConnectionLostError, Error, FailedAllConnectAttemptsError, - ReadTimeoutError, UnsupportedProtocolVersionError, ) from stompman.frames import ( AbortFrame, AckFrame, - AnyFrame, + AnyClientFrame, + AnyServerFrame, BeginFrame, - ClientFrame, CommitFrame, ConnectedFrame, ConnectFrame, @@ -24,7 +23,6 @@ NackFrame, ReceiptFrame, SendFrame, - ServerFrame, SubscribeFrame, UnsubscribeFrame, ) @@ -34,34 +32,31 @@ "AbortFrame", "AbstractConnection", "AckFrame", - "FailedAllConnectAttemptsError", - "AnyFrame", + "AnyClientFrame", "AnyListeningEvent", - "BaseListenEvent", + "AnyServerFrame", "BeginFrame", - "ClientFrame", + "Client", "CommitFrame", - "ConnectedFrame", - "ConnectError", "ConnectFrame", + "ConnectedFrame", + "Connection", "ConnectionConfirmationTimeoutError", + "ConnectionLostError", "ConnectionParameters", "DisconnectFrame", + "Error", "ErrorEvent", "ErrorFrame", + "FailedAllConnectAttemptsError", "Heartbeat", "HeartbeatEvent", "HeartbeatFrame", "MessageEvent", "MessageFrame", "NackFrame", - "ReadTimeoutError", "ReceiptFrame", "SendFrame", - "ServerFrame", - "Client", - "Connection", - "Error", "SubscribeFrame", "UnsubscribeFrame", "UnsupportedProtocolVersionError", diff --git a/stompman/client.py b/stompman/client.py index 922c9cc..68ce490 100644 --- a/stompman/client.py +++ b/stompman/client.py @@ -8,7 +8,6 @@ from stompman.connection import AbstractConnection, Connection, ConnectionParameters from stompman.errors import ( - ConnectError, ConnectionConfirmationTimeoutError, FailedAllConnectAttemptsError, UnsupportedProtocolVersionError, @@ -80,12 +79,9 @@ async def _connect_to_one_server(self, server: ConnectionParameters) -> Abstract read_timeout=self.read_timeout, read_max_chunk_size=self.read_max_chunk_size, ) - try: - await connection.connect() - except ConnectError: - await asyncio.sleep(self.connect_retry_interval * (attempt + 1)) - else: + if await connection.connect(): return connection + await asyncio.sleep(self.connect_retry_interval * (attempt + 1)) return None async def _connect_to_any_server(self) -> AbstractConnection: diff --git a/stompman/connection.py b/stompman/connection.py index c693317..901a9b6 100644 --- a/stompman/connection.py +++ b/stompman/connection.py @@ -3,8 +3,8 @@ from dataclasses import dataclass, field from typing import Protocol, TypeVar, cast -from stompman.errors import ConnectError, ReadTimeoutError -from stompman.frames import AnyRealFrame, ClientFrame, ServerFrame +from stompman.errors import ConnectionLostError +from stompman.frames import AnyClientFrame, AnyServerFrame from stompman.protocol import NEWLINE, Parser, dump_frame @@ -16,7 +16,7 @@ class ConnectionParameters: passcode: str = field(repr=False) -FrameT = TypeVar("FrameT", bound=AnyRealFrame) +FrameT = TypeVar("FrameT", bound=AnyClientFrame | AnyServerFrame) @dataclass @@ -26,11 +26,11 @@ class AbstractConnection(Protocol): read_timeout: int read_max_chunk_size: int - async def connect(self) -> None: ... + async def connect(self) -> bool: ... async def close(self) -> None: ... def write_heartbeat(self) -> None: ... - async def write_frame(self, frame: ClientFrame) -> None: ... - def read_frames(self) -> AsyncGenerator[ServerFrame, None]: ... + async def write_frame(self, frame: AnyClientFrame) -> None: ... + def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]: ... async def read_frame_of_type(self, type_: type[FrameT]) -> FrameT: while True: @@ -48,14 +48,15 @@ class Connection(AbstractConnection): reader: asyncio.StreamReader = field(init=False) writer: asyncio.StreamWriter = field(init=False) - async def connect(self) -> None: + async def connect(self) -> bool: try: self.reader, self.writer = await asyncio.wait_for( asyncio.open_connection(self.connection_parameters.host, self.connection_parameters.port), timeout=self.connect_timeout, ) - except (TimeoutError, ConnectionError) as exception: - raise ConnectError(self.connection_parameters) from exception + except (TimeoutError, ConnectionError): + return False + return True async def close(self) -> None: self.writer.close() @@ -64,7 +65,7 @@ async def close(self) -> None: def write_heartbeat(self) -> None: return self.writer.write(NEWLINE) - async def write_frame(self, frame: ClientFrame) -> None: + async def write_frame(self, frame: AnyClientFrame) -> None: self.writer.write(dump_frame(frame)) await self.writer.drain() @@ -75,14 +76,14 @@ async def _read_non_empty_bytes(self) -> bytes: await asyncio.sleep(0) return chunk - async def read_frames(self) -> AsyncGenerator[ServerFrame, None]: + async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]: parser = Parser() while True: try: raw_frames = await asyncio.wait_for(self._read_non_empty_bytes(), timeout=self.read_timeout) except TimeoutError as exception: - raise ReadTimeoutError(self.read_timeout) from exception + raise ConnectionLostError(self.read_timeout) from exception - for frame in cast(Iterator[ServerFrame], parser.load_frames(raw_frames)): + for frame in cast(Iterator[AnyServerFrame], parser.load_frames(raw_frames)): yield frame diff --git a/stompman/errors.py b/stompman/errors.py index f816696..e2b6d3f 100644 --- a/stompman/errors.py +++ b/stompman/errors.py @@ -22,11 +22,6 @@ class UnsupportedProtocolVersionError(Error): supported_version: str -@dataclass -class ConnectError(Error): - connection_parameters: "ConnectionParameters" - - @dataclass class FailedAllConnectAttemptsError(Error): servers: list["ConnectionParameters"] @@ -36,5 +31,5 @@ class FailedAllConnectAttemptsError(Error): @dataclass -class ReadTimeoutError(Error): +class ConnectionLostError(Error): timeout: int diff --git a/stompman/frames.py b/stompman/frames.py index 05d4e79..b6cbea3 100644 --- a/stompman/frames.py +++ b/stompman/frames.py @@ -244,7 +244,7 @@ class HeartbeatFrame: ... } FRAMES_TO_COMMANDS = {value: key for key, value in COMMANDS_TO_FRAMES.items()} -ClientFrame = ( +AnyClientFrame = ( SendFrame | SubscribeFrame | UnsubscribeFrame @@ -257,6 +257,4 @@ class HeartbeatFrame: ... | ConnectFrame | StompFrame ) -ServerFrame = ConnectedFrame | MessageFrame | ReceiptFrame | ErrorFrame -AnyRealFrame = ClientFrame | ServerFrame -AnyFrame = AnyRealFrame | HeartbeatFrame +AnyServerFrame = ConnectedFrame | MessageFrame | ReceiptFrame | ErrorFrame diff --git a/stompman/protocol.py b/stompman/protocol.py index 6bf4b1b..cac2292 100644 --- a/stompman/protocol.py +++ b/stompman/protocol.py @@ -7,8 +7,8 @@ from stompman.frames import ( COMMANDS_TO_FRAMES, FRAMES_TO_COMMANDS, - AnyFrame, - AnyRealFrame, + AnyClientFrame, + AnyServerFrame, HeartbeatFrame, ) @@ -37,7 +37,7 @@ def dump_header(key: str, value: str) -> bytes: return f"{escaped_key}:{escaped_value}\n".encode() -def dump_frame(frame: AnyRealFrame) -> bytes: +def dump_frame(frame: AnyClientFrame | AnyServerFrame) -> bytes: lines = ( FRAMES_TO_COMMANDS[type(frame)], NEWLINE, @@ -79,7 +79,7 @@ def parse_headers(buffer: list[bytes]) -> tuple[str, str] | None: return (b"".join(key_buffer).decode(), b"".join(value_buffer).decode()) if key_parsed else None -def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyFrame | None: +def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyClientFrame | AnyServerFrame | None: command = b"".join(lines.popleft()) headers = {} @@ -101,7 +101,7 @@ class Parser: _previous_byte: bytes = field(default=b"", init=False) _headers_processed: bool = field(default=False, init=False) - def load_frames(self, raw_frames: bytes) -> Iterator[AnyFrame]: + def load_frames(self, raw_frames: bytes) -> Iterator[AnyClientFrame | AnyServerFrame | HeartbeatFrame]: buffer = deque(struct.unpack(f"{len(raw_frames)!s}c", raw_frames)) while buffer: byte = buffer.popleft() diff --git a/tests/test_client.py b/tests/test_client.py index c3d21d1..00b11a8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,13 +12,12 @@ AbortFrame, AbstractConnection, AckFrame, - AnyFrame, + AnyClientFrame, + AnyServerFrame, BeginFrame, Client, - ClientFrame, CommitFrame, ConnectedFrame, - ConnectError, ConnectFrame, ConnectionConfirmationTimeoutError, ConnectionParameters, @@ -33,7 +32,6 @@ NackFrame, ReceiptFrame, SendFrame, - ServerFrame, SubscribeFrame, UnsubscribeFrame, UnsupportedProtocolVersionError, @@ -48,34 +46,36 @@ class BaseMockConnection(AbstractConnection): read_timeout: int read_max_chunk_size: int - async def connect(self) -> None: ... + async def connect(self) -> bool: # noqa: PLR6301 + return True + async def close(self) -> None: ... def write_heartbeat(self) -> None: ... - async def write_frame(self, frame: ClientFrame) -> None: ... - async def read_frames(self) -> AsyncGenerator[ServerFrame, None]: # pragma: no cover + async def write_frame(self, frame: AnyClientFrame) -> None: ... + async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]: # pragma: no cover # noqa: PLR6301 await asyncio.Future() yield # type: ignore[misc] def create_spying_connection( - read_frames_yields: list[list[ServerFrame]], -) -> tuple[type[AbstractConnection], list[AnyFrame]]: + read_frames_yields: list[list[AnyServerFrame]], +) -> tuple[type[AbstractConnection], list[AnyClientFrame | AnyServerFrame | HeartbeatFrame]]: @dataclass class BaseCollectingConnection(BaseMockConnection): - async def write_frame(self, frame: ClientFrame) -> None: + async def write_frame(self, frame: AnyClientFrame) -> None: # noqa: PLR6301 collected_frames.append(frame) - async def read_frames(self) -> AsyncGenerator[ServerFrame, None]: + async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]: # noqa: PLR6301 for frame in next(read_frames_iterator): collected_frames.append(frame) yield frame read_frames_iterator = iter(read_frames_yields) - collected_frames: list[AnyFrame] = [] + collected_frames: list[AnyClientFrame | AnyServerFrame | HeartbeatFrame] = [] return BaseCollectingConnection, collected_frames -def get_read_frames_with_lifespan(read_frames: list[list[ServerFrame]]) -> list[list[ServerFrame]]: +def get_read_frames_with_lifespan(read_frames: list[list[AnyServerFrame]]) -> list[list[AnyServerFrame]]: return [ [ConnectedFrame(headers={"version": PROTOCOL_VERSION, "heart-beat": "1,1"})], *read_frames, @@ -83,7 +83,10 @@ def get_read_frames_with_lifespan(read_frames: list[list[ServerFrame]]) -> list[ ] -def assert_frames_between_lifespan_match(collected_frames: list[AnyFrame], expected_frames: list[AnyFrame]) -> None: +def assert_frames_between_lifespan_match( + collected_frames: list[AnyClientFrame | AnyServerFrame | HeartbeatFrame], + expected_frames: list[AnyClientFrame | AnyServerFrame | HeartbeatFrame], +) -> None: assert collected_frames[2:-2] == expected_frames @@ -104,13 +107,11 @@ async def test_client_connect_to_one_server_ok(ok_on_attempt: int, monkeypatch: attempts = 0 class MockConnection(BaseMockConnection): - async def connect(self) -> None: + async def connect(self) -> bool: assert self.connection_parameters == client.servers[0] - nonlocal attempts attempts += 1 - if attempts != ok_on_attempt: - raise ConnectError(client.servers[0]) + return attempts == ok_on_attempt sleep_mock = mock.AsyncMock() monkeypatch.setattr("asyncio.sleep", sleep_mock) @@ -122,8 +123,8 @@ async def connect(self) -> None: @pytest.mark.usefixtures("mock_sleep") async def test_client_connect_to_one_server_fails() -> None: class MockConnection(BaseMockConnection): - async def connect(self) -> None: - raise ConnectError(client.servers[0]) + async def connect(self) -> bool: # noqa: PLR6301 + return False client = EnrichedClient(connection_class=MockConnection) assert await client._connect_to_one_server(client.servers[0]) is None @@ -132,9 +133,8 @@ async def connect(self) -> None: @pytest.mark.usefixtures("mock_sleep") async def test_client_connect_to_any_server_ok() -> None: class MockConnection(BaseMockConnection): - async def connect(self) -> None: - if self.connection_parameters.port != successful_server.port: - raise ConnectError(self.connection_parameters) + async def connect(self) -> bool: + return self.connection_parameters.port == successful_server.port successful_server = ConnectionParameters("localhost", 10, "login", "pass") client = EnrichedClient( @@ -153,8 +153,8 @@ async def connect(self) -> None: @pytest.mark.usefixtures("mock_sleep") async def test_client_connect_to_any_server_fails() -> None: class MockConnection(BaseMockConnection): - async def connect(self) -> None: - raise ConnectError(client.servers[0]) + async def connect(self) -> bool: # noqa: PLR6301 + return False client = EnrichedClient( servers=[ @@ -216,7 +216,7 @@ async def timeout(future: Awaitable[Any], timeout: float) -> Any: # noqa: ANN40 client = EnrichedClient(connection_class=BaseMockConnection) with pytest.raises(ConnectionConfirmationTimeoutError) as exc_info: - await client.__aenter__() + await client.__aenter__() # noqa: PLC2801 assert exc_info.value == ConnectionConfirmationTimeoutError(client.connection_confirmation_timeout) @@ -229,7 +229,7 @@ async def test_client_lifespan_unsupported_protocol_version() -> None: client = EnrichedClient(connection_class=connection_class) with pytest.raises(UnsupportedProtocolVersionError) as exc_info: - await client.__aenter__() + await client.__aenter__() # noqa: PLC2801 assert exc_info.value == UnsupportedProtocolVersionError( given_version=given_version, supported_version=PROTOCOL_VERSION diff --git a/tests/test_connection.py b/tests/test_connection.py index 99597c0..0f4d9bb 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,13 +7,12 @@ import pytest from stompman import ( + AnyServerFrame, ConnectedFrame, - ConnectError, Connection, + ConnectionLostError, ConnectionParameters, HeartbeatFrame, - ReadTimeoutError, - ServerFrame, ) from stompman.frames import CommitFrame @@ -74,7 +73,7 @@ class MockReader: connection.write_heartbeat() await connection.write_frame(CommitFrame(headers={"transaction": "transaction"})) - async def take_frames(count: int) -> list[ServerFrame]: + async def take_frames(count: int) -> list[AnyServerFrame]: frames = [] async for frame in connection.read_frames(): frames.append(frame) @@ -101,14 +100,12 @@ async def take_frames(count: int) -> list[ServerFrame]: async def test_connection_timeout(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None: mock_wait_for(monkeypatch) - with pytest.raises(ConnectError): - await connection.connect() + assert not await connection.connect() async def test_connection_error(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None: monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(side_effect=ConnectionError)) - with pytest.raises(ConnectError): - await connection.connect() + assert not await connection.connect() async def test_read_timeout(monkeypatch: pytest.MonkeyPatch, connection: Connection) -> None: @@ -118,5 +115,5 @@ async def test_read_timeout(monkeypatch: pytest.MonkeyPatch, connection: Connect ) await connection.connect() mock_wait_for(monkeypatch) - with pytest.raises(ReadTimeoutError): + with pytest.raises(ConnectionLostError): [frame async for frame in connection.read_frames()] diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 81debd8..c2abed2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -7,7 +7,7 @@ HeartbeatFrame, MessageFrame, ) -from stompman.frames import AckFrame, ClientFrame, ServerFrame +from stompman.frames import AckFrame, AnyClientFrame, AnyServerFrame from stompman.protocol import Parser, dump_frame @@ -37,7 +37,7 @@ ), ], ) -def test_dump_frame(frame: ClientFrame, dumped_frame: bytes) -> None: +def test_dump_frame(frame: AnyClientFrame, dumped_frame: bytes) -> None: assert dump_frame(frame) == dumped_frame @@ -223,5 +223,5 @@ def test_dump_frame(frame: ClientFrame, dumped_frame: bytes) -> None: (b"SOME_COMMAND\nhead:\nheader:1.1\n\n\x00", []), ], ) -def test_load_frames(raw_frames: bytes, loaded_frames: list[ServerFrame]) -> None: +def test_load_frames(raw_frames: bytes, loaded_frames: list[AnyServerFrame]) -> None: assert list(Parser().load_frames(raw_frames)) == loaded_frames