Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.
Closed
55 changes: 55 additions & 0 deletions async-usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from realtime.connection import Socket
import asyncio
import uuid

def callback1(payload):
print(f"c1: {payload}")

def callback2(payload):
print(f"c2: {payload}")


async def main():

TOKEN = ""
URLsink = f"ws://127.0.0.1:4000/socket/websocket?token={TOKEN}&vsn=2.0.0"

client = Socket(URLsink)

await client.connect()

# fire and forget the listening routine
listen_task = asyncio.ensure_future(client.listen())

channel_s = client.set_channel("yourchannel")
await channel_s.join()
channel_s.on("test_event", None, callback1)

# non sense elixir handler, we would not have an event on a reply
#def handle_in("request_ping", payload, socket) do
# push(socket, "test_event", %{body: payload})
# {:noreply, socket}
#end

await channel_s.send("request_ping", "this is my payload 1", None)
await channel_s.send("request_ping", "this is my payload 2", None)
await channel_s.send("request_ping", "this is my payload 3", None)

# proper relpy elixir handler
#def handle_in("ping", payload, socket) do
# {:reply, {:ok, payload}, socket}
#end

ref = str(uuid.uuid4())
channel_s.on(None, ref, callback2)
await channel_s.send("ping", "this is my ping payload", ref)

# we give it some time to complete
await asyncio.sleep(15)

# proper shut down
listen_task.cancel()

if __name__ == '__main__':
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
14 changes: 12 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ python = "^3.8"
websockets = "^11.0"
python-dateutil = "^2.8.1"
typing-extensions = "^4.2.0"
uuid = "^1.30"

[tool.poetry.dev-dependencies]
pytest = "^7.2.0"
Expand Down
74 changes: 57 additions & 17 deletions realtime/channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import logging
import json
import uuid
from typing import Any, List, Dict, TYPE_CHECKING, NamedTuple
from realtime.message import *

from realtime.types import Callback

Expand All @@ -13,13 +15,15 @@
class CallbackListener(NamedTuple):
"""A tuple with `event` and `callback` """
event: str
ref: str
callback: Callback


class Channel:
"""
`Channel` is an abstraction for a topic listener for an existing socket connection.
Each Channel has its own topic and a list of event-callbacks that responds to messages.
A client can also send messages to a channel and register callback when expecting replies.
Should only be instantiated through `connection.Socket().set_channel(topic)`
Topic-Channel has a 1-many relationship.
"""
Expand All @@ -35,45 +39,81 @@ def __init__(self, socket: Socket, topic: str, params: Dict[str, Any] = {}) -> N
self.topic = topic
self.listeners: List[CallbackListener] = []
self.joined = False
self.join_ref = str(uuid.uuid4())
self.control_msg_ref = ""

def join(self) -> Channel:
async def join(self) -> None:
"""
Wrapper for async def _join() to expose a non-async interface
Essentially gets the only event loop and attempt joining a topic
:return: Channel
Coroutine that attempts to join Phoenix Realtime server via a certain topic
:return: None
"""
loop = asyncio.get_event_loop() # TODO: replace with get_running_loop
loop.run_until_complete(self._join())
return self
if self.socket.version == 1:
join_req = dict(topic=self.topic, event=ChannelEvents.join,
payload={}, ref=None)
elif self.socket.version == 2:
#[join_reference, message_reference, topic_name, event_name, payload]
self.control_msg_ref = str(uuid.uuid4())
join_req = [self.join_ref, self.control_msg_ref, self.topic, ChannelEvents.join, self.params]

try:
await self.socket.ws_connection.send(json.dumps(join_req))
except Exception as e:
print(e)
return

async def _join(self) -> None:
async def leave(self) -> None:
"""
Coroutine that attempts to join Phoenix Realtime server via a certain topic
Coroutine that attempts to leave Phoenix Realtime server via a certain topic
:return: None
"""
join_req = dict(topic=self.topic, event="phx_join",
if self.socket.version == 1:
leave_req = dict(topic=self.topic, event=ChannelEvents.leave,
payload={}, ref=None)
elif self.socket.version == 2:
leave_req = [self.join_ref, None, self.topic, ChannelEvents.leave, {}]

try:
await self.socket.ws_connection.send(json.dumps(join_req))
await self.socket.ws_connection.send(json.dumps(leave_req))
except Exception as e:
print(str(e)) # TODO: better error propagation
print(e)
return

def on(self, event: str, callback: Callback) -> Channel:
def on(self, event: str, ref: str, callback: Callback) -> Channel:
"""
:param event: A specific event will have a specific callback
:param ref: A specific reference that will have a specific callback
:param callback: Callback that takes msg payload as its first argument
:return: Channel
"""
cl = CallbackListener(event=event, callback=callback)
cl = CallbackListener(event=event, ref=ref, callback=callback)
self.listeners.append(cl)
return self

def off(self, event: str) -> None:
def off(self, event: str, ref: str) -> None:
"""
:param event: Stop responding to a certain event
:param event: Stop responding to a certain reference
:return: None
"""
self.listeners = [
callback for callback in self.listeners if callback.event != event]
callback for callback in self.listeners if (callback.event != event and callback.ref != ref)]

async def send(self, event_name: str, payload: str, ref: str) -> None:
"""
Coroutine that attempts to join Phoenix Realtime server via a certain topic
:param event_name: The event_name: it must match the first argument of a handle_in function on the server channel module.
:param payload: The payload to be sent to the phoenix server
:param ref: The message reference that the server will use for replying
:return: None
"""
if self.socket.version == 1:
msg = dict(topic=self.topic, event=event_name,
payload=payload, ref=None)
elif self.socket.version == 2:
msg = [None, ref, self.topic, event_name, payload]

try:
await self.socket.ws_connection.send(json.dumps(msg))
except Exception as e:
print(e)
return
95 changes: 60 additions & 35 deletions realtime/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import pdb
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, List, Dict, TypeVar, DefaultDict
Expand Down Expand Up @@ -31,54 +32,69 @@ def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:


class Socket:
def __init__(self, url: str, auto_reconnect: bool = False, params: Dict[str, Any] = {}, hb_interval: int = 5) -> None:
def __init__(self, url: str, auto_reconnect: bool = False, params: Dict[str, Any] = {}, hb_interval: int = 30, version: int = 2) -> None:
"""
`Socket` is the abstraction for an actual socket connection that receives and 'reroutes' `Message` according to its `topic` and `event`.
Socket-Channel has a 1-many relationship.
Socket-Topic has a 1-many relationship.
:param url: Websocket URL of the Realtime server. starts with `ws://` or `wss://`
:param params: Optional parameters for connection.
:param hb_interval: WS connection is kept alive by sending a heartbeat message. Optional, defaults to 5.
:param hb_interval: WS connection is kept alive by sending a heartbeat message. Optional, defaults to 30.
:param version: phoenix JSON serializer version.
"""
self.url = url
self.channels = defaultdict(list)
self.connected = False
self.params = params
self.hb_interval = hb_interval
self.ws_connection: websockets.client.WebSocketClientProtocol
self.kept_alive = False
self.kept_alive = set()
self.auto_reconnect = auto_reconnect
self.version = version

self.channels: DefaultDict[str, List[Channel]] = defaultdict(list)

@ensure_connection
def listen(self) -> None:
"""
Wrapper for async def _listen() to expose a non-async interface
In most cases, this should be the last method executed as it starts an infinite listening loop.
:return: None
"""
loop = asyncio.get_event_loop() # TODO: replace with get_running_loop
loop.run_until_complete(asyncio.gather(
self._listen(), self._keep_alive()))

async def _listen(self) -> None:
async def listen(self) -> None:
"""
An infinite loop that keeps listening.
:return: None
"""
self.kept_alive.add(asyncio.ensure_future(self.keep_alive()))

while True:
try:
msg = await self.ws_connection.recv()
msg = Message(**json.loads(msg))

if self.version == 1 :
msg = Message(**json.loads(msg))
elif self.version == 2:
msg_array = json.loads(msg)
msg = Message(join_ref=msg_array[0], ref= msg_array[1], topic=msg_array[2], event= msg_array[3], payload= msg_array[4])
if msg.event == ChannelEvents.reply:
continue
for channel in self.channels.get(msg.topic, []):
if msg.ref == channel.control_msg_ref :
if msg.payload["status"] == "error":
logging.info(f"Error joining channel: {msg.topic} - {msg.payload['response']['reason']}")
break
elif msg.payload["status"] == "ok":
logging.info(f"Successfully joined {msg.topic}")
continue
else:
for cl in channel.listeners:
if cl.ref in ["*", msg.ref]:
cl.callback(msg.payload)

if msg.event == ChannelEvents.close:
for channel in self.channels.get(msg.topic, []):
if msg.join_ref == channel.join_ref :
logging.info(f"Successfully left {msg.topic}")
continue

for channel in self.channels.get(msg.topic, []):
for cl in channel.listeners:
if cl.event in ["*", msg.event]:
cl.callback(msg.payload)

except websockets.exceptions.ConnectionClosed:
if self.auto_reconnect:
logging.info("Connection with server closed, trying to reconnect...")
Expand All @@ -90,37 +106,46 @@ async def _listen(self) -> None:
logging.exception("Connection with the server closed.")
break

def connect(self) -> None:
"""
Wrapper for async def _connect() to expose a non-async interface
"""
loop = asyncio.get_event_loop() # TODO: replace with get_running
loop.run_until_complete(self._connect())
self.connected = True
except asyncio.CancelledError:
logging.info("Listen task was cancelled.")
await self.leave_all()

async def _connect(self) -> None:
except Exception as e:
logging.error(f"Unexpected error in listen: {e}")

async def connect(self) -> None:
ws_connection = await websockets.connect(self.url)

if ws_connection.open:
logging.info("Connection was successful")
self.ws_connection = ws_connection
self.connected = True
logging.info("Connection was successful")
else:
raise Exception("Connection Failed")

async def leave_all(self) -> None:
for channel in self.channels:
for chan in self.channels.get(channel, []):
await chan.leave()

async def _keep_alive(self) -> None:
async def keep_alive(self) -> None:
"""
Sending heartbeat to server every 5 seconds
Ping - pong messages to verify connection is alive
"""
while True:
try:
data = dict(
topic=PHOENIX_CHANNEL,
event=ChannelEvents.heartbeat,
payload=HEARTBEAT_PAYLOAD,
ref=None,
)
if self.version == 1 :
data = dict(
topic=PHOENIX_CHANNEL,
event=ChannelEvents.heartbeat,
payload=HEARTBEAT_PAYLOAD,
ref=None,
)
elif self.version == 2 :
# [null,"4","phoenix","heartbeat",{}]
data = [None, None, PHOENIX_CHANNEL, ChannelEvents.heartbeat, HEARTBEAT_PAYLOAD]

await self.ws_connection.send(json.dumps(data))
await asyncio.sleep(self.hb_interval)
except websockets.exceptions.ConnectionClosed:
Expand All @@ -144,10 +169,10 @@ def set_channel(self, topic: str) -> Channel:

def summary(self) -> None:
"""
Prints a list of topics and event the socket is listening to
Prints a list of topics and event, and reference that the socket is listening to
:return: None
"""
for topic, chans in self.channels.items():
for chan in chans:
print(
f"Topic: {topic} | Events: {[e for e, _ in chan.callbacks]}]")
f"Topic: {topic} | Events: {[e for e, _, _ in chan.listeners]} | References: {[r for _, r, _ in chan.listeners]}]")
Loading