diff --git a/async-usage.py b/async-usage.py new file mode 100644 index 00000000..880ddd00 --- /dev/null +++ b/async-usage.py @@ -0,0 +1,55 @@ +from realtime.connection import Socket +import asyncio +import uuid + +def callback1(payload): + print(f"c1: {payload}") + +def callback2(payload): + print(f"c2: {payload}") + + +async def main(): + + TOKEN = "" + URLsink = f"ws://127.0.0.1:4000/socket/websocket?token={TOKEN}&vsn=2.0.0" + + client = Socket(URLsink) + + await client.connect() + + # fire and forget the listening routine + listen_task = asyncio.ensure_future(client.listen()) + + channel_s = client.set_channel("yourchannel") + await channel_s.join() + channel_s.on("test_event", None, callback1) + + # non sense elixir handler, we would not have an event on a reply + #def handle_in("request_ping", payload, socket) do + # push(socket, "test_event", %{body: payload}) + # {:noreply, socket} + #end + + await channel_s.send("request_ping", "this is my payload 1", None) + await channel_s.send("request_ping", "this is my payload 2", None) + await channel_s.send("request_ping", "this is my payload 3", None) + + # proper relpy elixir handler + #def handle_in("ping", payload, socket) do + # {:reply, {:ok, payload}, socket} + #end + + ref = str(uuid.uuid4()) + channel_s.on(None, ref, callback2) + await channel_s.send("ping", "this is my ping payload", ref) + + # we give it some time to complete + await asyncio.sleep(15) + + # proper shut down + listen_task.cancel() + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) diff --git a/poetry.lock b/poetry.lock index a74d67a8..029988bc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "colorama" @@ -131,6 +131,16 @@ files = [ {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, ] +[[package]] +name = "uuid" +version = "1.30" +description = "UUID object and generation functions (Python 2.3 or higher)" +optional = false +python-versions = "*" +files = [ + {file = "uuid-1.30.tar.gz", hash = "sha256:1f87cc004ac5120466f36c5beae48b4c48cc411968eed0eaecd3da82aa96193f"}, +] + [[package]] name = "websockets" version = "11.0.3" @@ -213,4 +223,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "542d5628e562e1d06aba9216b8b4308a7d5723a64a146f5311aae7b18a9e69e5" +content-hash = "d5bdcceb9e4ab6423b4c727ea2a0b4cde830fa00fa0a9f9d72f4e0d9fad77f9e" diff --git a/pyproject.toml b/pyproject.toml index 59adaef6..0fc4d8ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ python = "^3.8" websockets = "^11.0" python-dateutil = "^2.8.1" typing-extensions = "^4.2.0" +uuid = "^1.30" [tool.poetry.dev-dependencies] pytest = "^7.2.0" diff --git a/realtime/channel.py b/realtime/channel.py index e4ac9084..829f4faa 100644 --- a/realtime/channel.py +++ b/realtime/channel.py @@ -1,8 +1,10 @@ from __future__ import annotations -import asyncio +import logging import json +import uuid from typing import Any, List, Dict, TYPE_CHECKING, NamedTuple +from realtime.message import * from realtime.types import Callback @@ -13,6 +15,7 @@ class CallbackListener(NamedTuple): """A tuple with `event` and `callback` """ event: str + ref: str callback: Callback @@ -20,6 +23,7 @@ class Channel: """ `Channel` is an abstraction for a topic listener for an existing socket connection. Each Channel has its own topic and a list of event-callbacks that responds to messages. + A client can also send messages to a channel and register callback when expecting replies. Should only be instantiated through `connection.Socket().set_channel(topic)` Topic-Channel has a 1-many relationship. """ @@ -35,45 +39,81 @@ def __init__(self, socket: Socket, topic: str, params: Dict[str, Any] = {}) -> N self.topic = topic self.listeners: List[CallbackListener] = [] self.joined = False + self.join_ref = str(uuid.uuid4()) + self.control_msg_ref = "" - def join(self) -> Channel: + async def join(self) -> None: """ - Wrapper for async def _join() to expose a non-async interface - Essentially gets the only event loop and attempt joining a topic - :return: Channel + Coroutine that attempts to join Phoenix Realtime server via a certain topic + :return: None """ - loop = asyncio.get_event_loop() # TODO: replace with get_running_loop - loop.run_until_complete(self._join()) - return self + if self.socket.version == 1: + join_req = dict(topic=self.topic, event=ChannelEvents.join, + payload={}, ref=None) + elif self.socket.version == 2: + #[join_reference, message_reference, topic_name, event_name, payload] + self.control_msg_ref = str(uuid.uuid4()) + join_req = [self.join_ref, self.control_msg_ref, self.topic, ChannelEvents.join, self.params] + + try: + await self.socket.ws_connection.send(json.dumps(join_req)) + except Exception as e: + print(e) + return - async def _join(self) -> None: + async def leave(self) -> None: """ - Coroutine that attempts to join Phoenix Realtime server via a certain topic + Coroutine that attempts to leave Phoenix Realtime server via a certain topic :return: None """ - join_req = dict(topic=self.topic, event="phx_join", + if self.socket.version == 1: + leave_req = dict(topic=self.topic, event=ChannelEvents.leave, payload={}, ref=None) + elif self.socket.version == 2: + leave_req = [self.join_ref, None, self.topic, ChannelEvents.leave, {}] try: - await self.socket.ws_connection.send(json.dumps(join_req)) + await self.socket.ws_connection.send(json.dumps(leave_req)) except Exception as e: - print(str(e)) # TODO: better error propagation + print(e) return - def on(self, event: str, callback: Callback) -> Channel: + def on(self, event: str, ref: str, callback: Callback) -> Channel: """ :param event: A specific event will have a specific callback + :param ref: A specific reference that will have a specific callback :param callback: Callback that takes msg payload as its first argument :return: Channel """ - cl = CallbackListener(event=event, callback=callback) + cl = CallbackListener(event=event, ref=ref, callback=callback) self.listeners.append(cl) return self - def off(self, event: str) -> None: + def off(self, event: str, ref: str) -> None: """ :param event: Stop responding to a certain event + :param event: Stop responding to a certain reference :return: None """ self.listeners = [ - callback for callback in self.listeners if callback.event != event] + callback for callback in self.listeners if (callback.event != event and callback.ref != ref)] + + async def send(self, event_name: str, payload: str, ref: str) -> None: + """ + Coroutine that attempts to join Phoenix Realtime server via a certain topic + :param event_name: The event_name: it must match the first argument of a handle_in function on the server channel module. + :param payload: The payload to be sent to the phoenix server + :param ref: The message reference that the server will use for replying + :return: None + """ + if self.socket.version == 1: + msg = dict(topic=self.topic, event=event_name, + payload=payload, ref=None) + elif self.socket.version == 2: + msg = [None, ref, self.topic, event_name, payload] + + try: + await self.socket.ws_connection.send(json.dumps(msg)) + except Exception as e: + print(e) + return \ No newline at end of file diff --git a/realtime/connection.py b/realtime/connection.py index cc017a36..25b3ed89 100644 --- a/realtime/connection.py +++ b/realtime/connection.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import pdb from collections import defaultdict from functools import wraps from typing import Any, Callable, List, Dict, TypeVar, DefaultDict @@ -31,14 +32,15 @@ def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: class Socket: - def __init__(self, url: str, auto_reconnect: bool = False, params: Dict[str, Any] = {}, hb_interval: int = 5) -> None: + def __init__(self, url: str, auto_reconnect: bool = False, params: Dict[str, Any] = {}, hb_interval: int = 30, version: int = 2) -> None: """ `Socket` is the abstraction for an actual socket connection that receives and 'reroutes' `Message` according to its `topic` and `event`. Socket-Channel has a 1-many relationship. Socket-Topic has a 1-many relationship. :param url: Websocket URL of the Realtime server. starts with `ws://` or `wss://` :param params: Optional parameters for connection. - :param hb_interval: WS connection is kept alive by sending a heartbeat message. Optional, defaults to 5. + :param hb_interval: WS connection is kept alive by sending a heartbeat message. Optional, defaults to 30. + :param version: phoenix JSON serializer version. """ self.url = url self.channels = defaultdict(list) @@ -46,39 +48,53 @@ def __init__(self, url: str, auto_reconnect: bool = False, params: Dict[str, Any self.params = params self.hb_interval = hb_interval self.ws_connection: websockets.client.WebSocketClientProtocol - self.kept_alive = False + self.kept_alive = set() self.auto_reconnect = auto_reconnect + self.version = version self.channels: DefaultDict[str, List[Channel]] = defaultdict(list) @ensure_connection - def listen(self) -> None: - """ - Wrapper for async def _listen() to expose a non-async interface - In most cases, this should be the last method executed as it starts an infinite listening loop. - :return: None - """ - loop = asyncio.get_event_loop() # TODO: replace with get_running_loop - loop.run_until_complete(asyncio.gather( - self._listen(), self._keep_alive())) - - async def _listen(self) -> None: + async def listen(self) -> None: """ An infinite loop that keeps listening. :return: None """ + self.kept_alive.add(asyncio.ensure_future(self.keep_alive())) + while True: try: msg = await self.ws_connection.recv() - msg = Message(**json.loads(msg)) - + if self.version == 1 : + msg = Message(**json.loads(msg)) + elif self.version == 2: + msg_array = json.loads(msg) + msg = Message(join_ref=msg_array[0], ref= msg_array[1], topic=msg_array[2], event= msg_array[3], payload= msg_array[4]) if msg.event == ChannelEvents.reply: - continue + for channel in self.channels.get(msg.topic, []): + if msg.ref == channel.control_msg_ref : + if msg.payload["status"] == "error": + logging.info(f"Error joining channel: {msg.topic} - {msg.payload['response']['reason']}") + break + elif msg.payload["status"] == "ok": + logging.info(f"Successfully joined {msg.topic}") + continue + else: + for cl in channel.listeners: + if cl.ref in ["*", msg.ref]: + cl.callback(msg.payload) + + if msg.event == ChannelEvents.close: + for channel in self.channels.get(msg.topic, []): + if msg.join_ref == channel.join_ref : + logging.info(f"Successfully left {msg.topic}") + continue for channel in self.channels.get(msg.topic, []): for cl in channel.listeners: if cl.event in ["*", msg.event]: cl.callback(msg.payload) + except websockets.exceptions.ConnectionClosed: if self.auto_reconnect: logging.info("Connection with server closed, trying to reconnect...") @@ -90,37 +106,46 @@ async def _listen(self) -> None: logging.exception("Connection with the server closed.") break - def connect(self) -> None: - """ - Wrapper for async def _connect() to expose a non-async interface - """ - loop = asyncio.get_event_loop() # TODO: replace with get_running - loop.run_until_complete(self._connect()) - self.connected = True + except asyncio.CancelledError: + logging.info("Listen task was cancelled.") + await self.leave_all() - async def _connect(self) -> None: + except Exception as e: + logging.error(f"Unexpected error in listen: {e}") + + async def connect(self) -> None: ws_connection = await websockets.connect(self.url) if ws_connection.open: - logging.info("Connection was successful") self.ws_connection = ws_connection self.connected = True + logging.info("Connection was successful") else: raise Exception("Connection Failed") + + async def leave_all(self) -> None: + for channel in self.channels: + for chan in self.channels.get(channel, []): + await chan.leave() - async def _keep_alive(self) -> None: + async def keep_alive(self) -> None: """ Sending heartbeat to server every 5 seconds Ping - pong messages to verify connection is alive """ while True: try: - data = dict( - topic=PHOENIX_CHANNEL, - event=ChannelEvents.heartbeat, - payload=HEARTBEAT_PAYLOAD, - ref=None, - ) + if self.version == 1 : + data = dict( + topic=PHOENIX_CHANNEL, + event=ChannelEvents.heartbeat, + payload=HEARTBEAT_PAYLOAD, + ref=None, + ) + elif self.version == 2 : + # [null,"4","phoenix","heartbeat",{}] + data = [None, None, PHOENIX_CHANNEL, ChannelEvents.heartbeat, HEARTBEAT_PAYLOAD] + await self.ws_connection.send(json.dumps(data)) await asyncio.sleep(self.hb_interval) except websockets.exceptions.ConnectionClosed: @@ -144,10 +169,10 @@ def set_channel(self, topic: str) -> Channel: def summary(self) -> None: """ - Prints a list of topics and event the socket is listening to + Prints a list of topics and event, and reference that the socket is listening to :return: None """ for topic, chans in self.channels.items(): for chan in chans: print( - f"Topic: {topic} | Events: {[e for e, _ in chan.callbacks]}]") + f"Topic: {topic} | Events: {[e for e, _, _ in chan.listeners]} | References: {[r for _, r, _ in chan.listeners]}]") diff --git a/realtime/message.py b/realtime/message.py index 9909d4db..87da6e0e 100644 --- a/realtime/message.py +++ b/realtime/message.py @@ -11,6 +11,7 @@ class Message: event: str payload: Dict[str, Any] ref: Any + join_ref: Any topic: str def __hash__(self): @@ -32,4 +33,4 @@ class ChannelEvents(str, Enum): PHOENIX_CHANNEL = "phoenix" -HEARTBEAT_PAYLOAD = {"msg": "ping"} +HEARTBEAT_PAYLOAD = {}