Skip to content

Commit 143d333

Browse files
committed
LISTEN/NOTIFY funcionality
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent 6b22424 commit 143d333

File tree

5 files changed

+173
-51
lines changed

5 files changed

+173
-51
lines changed

python/psqlpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
IsolationLevel,
88
KeepaliveConfig,
99
Listener,
10-
ListenerNotification,
10+
ListenerNotificationMsg,
1111
LoadBalanceHosts,
1212
QueryResult,
1313
ReadVariant,
@@ -28,7 +28,7 @@
2828
"IsolationLevel",
2929
"KeepaliveConfig",
3030
"Listener",
31-
"ListenerNotification",
31+
"ListenerNotificationMsg",
3232
"LoadBalanceHosts",
3333
"QueryResult",
3434
"ReadVariant",

python/psqlpy/_internal/__init__.pyi

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import types
22
from enum import Enum
33
from io import BytesIO
44
from ipaddress import IPv4Address, IPv6Address
5-
from typing import Any, Callable, Sequence, TypeVar
5+
from typing import Any, Awaitable, Callable, Sequence, TypeVar
66

77
from typing_extensions import Buffer, Self
88

@@ -1360,6 +1360,9 @@ class ConnectionPool:
13601360
res = await connection.execute(...)
13611361
```
13621362
"""
1363+
def listener(self: Self) -> Listener:
1364+
"""Create new listener."""
1365+
13631366
def close(self: Self) -> None:
13641367
"""Close the connection pool."""
13651368

@@ -1752,6 +1755,70 @@ class ConnectionPoolBuilder:
17521755
class Listener:
17531756
"""Result."""
17541757

1758+
connection: Connection
17551759

1756-
class ListenerNotification:
1757-
"""Result."""
1760+
def __aiter__(self: Self) -> Self: ...
1761+
async def __anext__(self: Self) -> ListenerNotificationMsg: ...
1762+
async def __aenter__(self: Self) -> Self: ...
1763+
async def __aexit__(
1764+
self: Self,
1765+
exception_type: type[BaseException] | None,
1766+
exception: BaseException | None,
1767+
traceback: types.TracebackType | None,
1768+
) -> None: ...
1769+
async def startup(self: Self) -> None:
1770+
"""Startup the listener.
1771+
1772+
Each listener MUST be started up.
1773+
"""
1774+
async def add_callback(
1775+
self: Self,
1776+
channel: str,
1777+
callback: Callable[
1778+
[str, str, int, Connection], Awaitable[None],
1779+
],
1780+
) -> None:
1781+
"""Add callback to the channel.
1782+
1783+
Callback must be async function and have signature like this:
1784+
```python
1785+
async def callback(
1786+
channel: str,
1787+
payload: str,
1788+
process_id: str,
1789+
connection: Connection,
1790+
) -> None:
1791+
...
1792+
```
1793+
"""
1794+
1795+
async def clear_channel_callbacks(self, channel: str) -> None:
1796+
"""Remove all callbacks for the channel.
1797+
1798+
### Parameters:
1799+
- `channel`: name of the channel.
1800+
"""
1801+
1802+
async def listen(self: Self) -> None:
1803+
"""Start listening.
1804+
1805+
Start actual listening.
1806+
In the background it creates task in Rust event loop.
1807+
You must save returned Future to the array.
1808+
"""
1809+
1810+
async def abort_listen(self: Self) -> None:
1811+
"""Abort listen.
1812+
1813+
If `listen()` method was called, stop listening,
1814+
else don't do anything.
1815+
"""
1816+
1817+
1818+
class ListenerNotificationMsg:
1819+
"""Listener message in async iterator."""
1820+
1821+
process_id: int
1822+
channel: str
1823+
payload: str
1824+
connection: Connection

src/driver/connection_pool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ impl ConnectionPool {
500500
Connection::new(None, Some(self.pool.clone()))
501501
}
502502

503-
pub async fn add_listener(
503+
pub fn listener(
504504
self_: pyo3::Py<Self>,
505505
) -> RustPSQLDriverPyResult<Listener> {
506506
let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| {

src/driver/listener.rs

Lines changed: 99 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,10 @@ impl ChannelCallbacks {
5555

5656

5757
#[derive(Clone, Debug)]
58-
#[pyclass]
5958
pub struct ListenerNotification {
60-
process_id: i32,
61-
channel: String,
62-
payload: String,
59+
pub process_id: i32,
60+
pub channel: String,
61+
pub payload: String,
6362
}
6463

6564
impl From::<Notification> for ListenerNotification {
@@ -73,13 +72,47 @@ impl From::<Notification> for ListenerNotification {
7372
}
7473

7574
#[pyclass]
76-
struct ListenerNotificationMsg {
75+
pub struct ListenerNotificationMsg {
7776
process_id: i32,
7877
channel: String,
7978
payload: String,
8079
connection: Connection,
8180
}
8281

82+
#[pymethods]
83+
impl ListenerNotificationMsg {
84+
#[getter]
85+
fn process_id(&self) -> i32 {
86+
self.process_id
87+
}
88+
89+
#[getter]
90+
fn channel(&self) -> String {
91+
self.channel.clone()
92+
}
93+
94+
#[getter]
95+
fn payload(&self) -> String {
96+
self.payload.clone()
97+
}
98+
99+
#[getter]
100+
fn connection(&self) -> Connection {
101+
self.connection.clone()
102+
}
103+
}
104+
105+
impl ListenerNotificationMsg {
106+
fn new(value: ListenerNotification, conn: Connection) -> Self {
107+
ListenerNotificationMsg {
108+
process_id: value.process_id,
109+
channel: String::from(value.channel),
110+
payload: String::from(value.payload),
111+
connection: conn,
112+
}
113+
}
114+
}
115+
83116
struct ListenerCallback {
84117
task_locals: Option<TaskLocals>,
85118
callback: Py<PyAny>,
@@ -111,7 +144,15 @@ impl ListenerCallback {
111144
if let Some(task_locals) = task_locals {
112145
tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move {
113146
let future = Python::with_gil(|py| {
114-
let awaitable = callback.call1(py, (lister_notification, connection)).unwrap();
147+
let awaitable = callback.call1(
148+
py,
149+
(
150+
lister_notification.channel,
151+
lister_notification.payload,
152+
lister_notification.process_id,
153+
connection,
154+
)
155+
).unwrap();
115156
pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)).unwrap()
116157
});
117158
future.await.unwrap();
@@ -226,6 +267,41 @@ impl Listener {
226267
return Err(RustPSQLDriverError::ListenerClosedError)
227268
}
228269

270+
fn __anext__(&self) -> RustPSQLDriverPyResult<Option<Py<PyAny>>> {
271+
let Some(client) = self.connection.db_client() else {
272+
return Err(RustPSQLDriverError::ListenerStartError(
273+
"Listener doesn't have underlying client, please call startup".into(),
274+
));
275+
};
276+
let Some(receiver) = self.receiver.clone() else {
277+
return Err(RustPSQLDriverError::ListenerStartError(
278+
"Listener doesn't have underlying receiver, please call startup".into(),
279+
));
280+
};
281+
282+
let is_listened_clone = self.is_listened.clone();
283+
let listen_query_clone = self.listen_query.clone();
284+
let connection = self.connection.clone();
285+
286+
let py_future = Python::with_gil(move |gil| {
287+
rustdriver_future(gil, async move {
288+
{
289+
call_listen(&is_listened_clone, &listen_query_clone, &client).await?;
290+
};
291+
let next_element = {
292+
let mut write_receiver = receiver.write().await;
293+
write_receiver.next().await
294+
};
295+
296+
let inner_notification = process_message(next_element)?;
297+
298+
Ok(ListenerNotificationMsg::new(inner_notification, connection))
299+
})
300+
});
301+
302+
Ok(Some(py_future?))
303+
}
304+
229305
#[getter]
230306
fn connection(&self) -> Connection {
231307
self.connection.clone()
@@ -280,40 +356,6 @@ impl Listener {
280356
Ok(())
281357
}
282358

283-
fn __anext__(&self) -> RustPSQLDriverPyResult<Option<PyObject>> {
284-
let Some(client) = self.connection.db_client() else {
285-
return Err(RustPSQLDriverError::ListenerStartError(
286-
"Listener doesn't have underlying client, please call startup".into(),
287-
));
288-
};
289-
let Some(receiver) = self.receiver.clone() else {
290-
return Err(RustPSQLDriverError::ListenerStartError(
291-
"Listener doesn't have underlying receiver, please call startup".into(),
292-
));
293-
};
294-
295-
let is_listened_clone = self.is_listened.clone();
296-
let listen_query_clone = self.listen_query.clone();
297-
298-
let py_future = Python::with_gil(move |gil| {
299-
rustdriver_future(gil, async move {
300-
{
301-
call_listen(&is_listened_clone, &listen_query_clone, &client).await?;
302-
};
303-
let next_element = {
304-
let mut write_receiver = receiver.write().await;
305-
write_receiver.next().await
306-
};
307-
308-
let inner_notification = process_message(next_element)?;
309-
310-
Ok(inner_notification)
311-
})
312-
});
313-
314-
Ok(Some(py_future?))
315-
}
316-
317359
#[pyo3(signature = (channel, callback))]
318360
async fn add_callback(
319361
&mut self,
@@ -337,10 +379,6 @@ impl Listener {
337379
callback,
338380
);
339381

340-
// let awaitable = callback.call1(()).unwrap();
341-
// println!("8888888 {:?}", awaitable);
342-
// let bbb = pyo3_async_runtimes::tokio::into_future(awaitable).unwrap();
343-
// println!("999999");
344382
{
345383
let mut write_channel_callbacks = self.channel_callbacks.write().await;
346384
write_channel_callbacks.add_callback(channel, listener_callback);
@@ -351,6 +389,15 @@ impl Listener {
351389
Ok(())
352390
}
353391

392+
async fn clear_channel_callbacks(&mut self, channel: String) {
393+
{
394+
let mut write_channel_callbacks = self.channel_callbacks.write().await;
395+
write_channel_callbacks.clear_channel_callbacks(channel);
396+
}
397+
398+
self.update_listen_query().await;
399+
}
400+
354401
async fn listen(&mut self) -> RustPSQLDriverPyResult<()> {
355402
let Some(client) = self.connection.db_client() else {
356403
return Err(RustPSQLDriverError::BaseConnectionError("test".into()));
@@ -401,6 +448,14 @@ impl Listener {
401448

402449
Ok(())
403450
}
451+
452+
async fn abort_listen(&mut self) {
453+
if let Some(listen_abort_handler) = &self.listen_abort_handler {
454+
listen_abort_handler.abort();
455+
}
456+
457+
self.listen_abort_handler = None;
458+
}
404459
}
405460

406461
async fn dispatch_callback(

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> {
3030
pymod.add_class::<driver::transaction::Transaction>()?;
3131
pymod.add_class::<driver::cursor::Cursor>()?;
3232
pymod.add_class::<driver::listener::Listener>()?;
33-
pymod.add_class::<driver::listener::ListenerNotification>()?;
33+
pymod.add_class::<driver::listener::ListenerNotificationMsg>()?;
3434
pymod.add_class::<driver::transaction_options::IsolationLevel>()?;
3535
pymod.add_class::<driver::transaction_options::SynchronousCommit>()?;
3636
pymod.add_class::<driver::transaction_options::ReadVariant>()?;

0 commit comments

Comments
 (0)