diff --git a/stompman/connection.py b/stompman/connection.py index 901a9b6..5fdd47a 100644 --- a/stompman/connection.py +++ b/stompman/connection.py @@ -1,13 +1,20 @@ import asyncio from collections.abc import AsyncGenerator, Iterator from dataclasses import dataclass, field -from typing import Protocol, TypeVar, cast +from typing import Protocol, Self, TypedDict, TypeVar, cast from stompman.errors import ConnectionLostError from stompman.frames import AnyClientFrame, AnyServerFrame from stompman.protocol import NEWLINE, Parser, dump_frame +class _MultiHostHostLike(TypedDict): + username: str | None + password: str | None + host: str | None + port: int | None + + @dataclass class ConnectionParameters: host: str @@ -15,6 +22,41 @@ class ConnectionParameters: login: str passcode: str = field(repr=False) + @classmethod + def from_pydantic_multihost_hosts(cls, hosts: list[_MultiHostHostLike]) -> list[Self]: + """Create connection parameters from a list of `MultiHostUrl` objects. + + .. code-block:: python + import stompman. + + ArtemisDsn = typing.Annotated[ + pydantic_core.MultiHostUrl, + pydantic.UrlConstraints( + host_required=True, + allowed_schemes=["tcp"], + ), + ] + + async with stompman.Client( + servers=stompman.ConnectionParameters.from_pydantic_multihost_hosts( + ArtemisDsn("tcp://lev:pass@host1:61616,host1:61617,host2:61616").hosts() + ), + ): + ... + """ + servers: list[Self] = [] + for host in hosts: + if host["host"] is None: + raise ValueError("host must be set") + if host["port"] is None: + raise ValueError("port must be set") + if host["username"] is None: + raise ValueError("username must be set") + if host["password"] is None: + raise ValueError("password must be set") + servers.append(cls(host=host["host"], port=host["port"], login=host["username"], passcode=host["password"])) + return servers + FrameT = TypeVar("FrameT", bound=AnyClientFrame | AnyServerFrame) diff --git a/tests/test_connection.py b/tests/test_connection.py index 0f4d9bb..fa95ac2 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -35,6 +35,18 @@ async def mock_impl(future: Awaitable[Any], timeout: int) -> Any: # noqa: ANN40 monkeypatch.setattr("asyncio.wait_for", mock_impl) +def test_connection_parameters_from_pydantic_multihost_hosts() -> None: + full_host: dict[str, Any] = {"username": "me", "password": "pass", "host": "localhost", "port": 1234} + assert ConnectionParameters.from_pydantic_multihost_hosts([{**full_host, "port": index} for index in range(5)]) == [ # type: ignore[typeddict-item] + ConnectionParameters(full_host["host"], index, full_host["username"], full_host["password"]) + for index in range(5) + ] + + for key in ("username", "password", "host", "port"): + with pytest.raises(ValueError, match=f"{key} must be set"): + assert ConnectionParameters.from_pydantic_multihost_hosts([{**full_host, key: None}, full_host]) # type: ignore[typeddict-item, list-item] + + async def test_connection_lifespan(connection: Connection, monkeypatch: pytest.MonkeyPatch) -> None: class MockWriter: close = mock.Mock()