Skip to content

Commit d5d4458

Browse files
committed
LISTEN/NOTIFY funcionality
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
2 parents 143d333 + 2a14188 commit d5d4458

File tree

6 files changed

+167
-163
lines changed

6 files changed

+167
-163
lines changed

src/driver/connection.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl InnerConnection {
3939
}
4040
}
4141
}
42-
42+
4343
pub async fn query<T>(
4444
&self,
4545
statement: &T,
@@ -71,8 +71,8 @@ impl InnerConnection {
7171
&self,
7272
statement: &T,
7373
params: &[&(dyn ToSql + Sync)],
74-
) -> RustPSQLDriverPyResult<Row>
75-
where T: ?Sized + ToStatement
74+
) -> RustPSQLDriverPyResult<Row>
75+
where T: ?Sized + ToStatement
7676
{
7777
match self {
7878
InnerConnection::PoolConn(pconn) => {
@@ -87,7 +87,7 @@ impl InnerConnection {
8787
pub async fn copy_in<T, U>(
8888
&self,
8989
statement: &T
90-
) -> RustPSQLDriverPyResult<CopyInSink<U>>
90+
) -> RustPSQLDriverPyResult<CopyInSink<U>>
9191
where
9292
T: ?Sized + ToStatement,
9393
U: Buf + 'static + Send

src/driver/connection_pool.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@ use crate::{
1111
};
1212

1313
use super::{
14-
common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs},
15-
connection::{Connection, InnerConnection},
16-
listener::Listener,
17-
utils::{build_connection_config, build_manager, build_tls},
14+
common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, connection::{Connection, InnerConnection}, listener::core::Listener, utils::{build_connection_config, build_manager, build_tls}
1815
};
1916

2017
/// Make new connection pool.

src/driver/listener.rs renamed to src/driver/listener/core.rs

Lines changed: 5 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,167 +1,19 @@
1-
use std::{
2-
collections::{hash_map::Entry, HashMap},
3-
sync::Arc,
4-
};
1+
use std::sync::Arc;
52

63
use futures::{stream, FutureExt, StreamExt, TryStreamExt};
74
use futures_channel::mpsc::UnboundedReceiver;
85
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
96
use postgres_openssl::MakeTlsConnector;
10-
use pyo3::{pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python};
11-
use pyo3_async_runtimes::TaskLocals;
7+
use pyo3::{pyclass, pymethods, Py, PyAny, PyErr, Python};
128
use tokio::{sync::RwLock, task::{AbortHandle, JoinHandle}};
13-
use tokio_postgres::{AsyncMessage, Config, Notification};
9+
use tokio_postgres::{AsyncMessage, Config};
1410

1511
use crate::{
16-
driver::utils::is_coroutine_function, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, runtime::{rustdriver_future, tokio_runtime}
17-
};
18-
19-
use super::{
20-
common_options::SslMode, connection::{Connection, InnerConnection}, utils::{build_tls, ConfiguredTLS}
12+
driver::{common_options::SslMode, connection::{Connection, InnerConnection}, utils::{build_tls, is_coroutine_function, ConfiguredTLS}}, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, runtime::{rustdriver_future, tokio_runtime}
2113
};
2214

23-
struct ChannelCallbacks(HashMap<String, Vec<ListenerCallback>>);
24-
25-
impl Default for ChannelCallbacks {
26-
fn default() -> Self {
27-
ChannelCallbacks(Default::default())
28-
}
29-
}
30-
31-
impl ChannelCallbacks {
32-
fn add_callback(&mut self, channel: String, callback: ListenerCallback) {
33-
match self.0.entry(channel) {
34-
Entry::Vacant(e) => {
35-
e.insert(vec![callback]);
36-
}
37-
Entry::Occupied(mut e) => {
38-
e.get_mut().push(callback);
39-
}
40-
};
41-
}
42-
43-
fn retrieve_channel_callbacks(&self, channel: String) -> Option<&Vec<ListenerCallback>> {
44-
self.0.get(&channel)
45-
}
46-
47-
fn clear_channel_callbacks(&mut self, channel: String) {
48-
self.0.remove(&channel);
49-
}
50-
51-
fn retrieve_all_channels(&self) -> Vec<&String> {
52-
self.0.keys().collect::<Vec<&String>>()
53-
}
54-
}
55-
56-
57-
#[derive(Clone, Debug)]
58-
pub struct ListenerNotification {
59-
pub process_id: i32,
60-
pub channel: String,
61-
pub payload: String,
62-
}
63-
64-
impl From::<Notification> for ListenerNotification {
65-
fn from(value: Notification) -> Self {
66-
ListenerNotification {
67-
process_id: value.process_id(),
68-
channel: String::from(value.channel()),
69-
payload: String::from(value.payload()),
70-
}
71-
}
72-
}
73-
74-
#[pyclass]
75-
pub struct ListenerNotificationMsg {
76-
process_id: i32,
77-
channel: String,
78-
payload: String,
79-
connection: Connection,
80-
}
15+
use super::structs::{ChannelCallbacks, ListenerCallback, ListenerNotification, ListenerNotificationMsg};
8116

82-
#[pymethods]
83-
impl ListenerNotificationMsg {
84-
#[getter]
85-
fn process_id(&self) -> i32 {
86-
self.process_id
87-
}
88-
89-
#[getter]
90-
fn channel(&self) -> String {
91-
self.channel.clone()
92-
}
93-
94-
#[getter]
95-
fn payload(&self) -> String {
96-
self.payload.clone()
97-
}
98-
99-
#[getter]
100-
fn connection(&self) -> Connection {
101-
self.connection.clone()
102-
}
103-
}
104-
105-
impl ListenerNotificationMsg {
106-
fn new(value: ListenerNotification, conn: Connection) -> Self {
107-
ListenerNotificationMsg {
108-
process_id: value.process_id,
109-
channel: String::from(value.channel),
110-
payload: String::from(value.payload),
111-
connection: conn,
112-
}
113-
}
114-
}
115-
116-
struct ListenerCallback {
117-
task_locals: Option<TaskLocals>,
118-
callback: Py<PyAny>,
119-
}
120-
121-
impl ListenerCallback {
122-
pub fn new(
123-
task_locals: Option<TaskLocals>,
124-
callback: Py<PyAny>,
125-
) -> Self {
126-
ListenerCallback {
127-
task_locals,
128-
callback,
129-
}
130-
}
131-
132-
async fn call(
133-
&self,
134-
lister_notification: ListenerNotification,
135-
connection: Connection,
136-
) -> RustPSQLDriverPyResult<()> {
137-
let (callback, task_locals) = Python::with_gil(|py| {
138-
if let Some(task_locals) = &self.task_locals {
139-
return (self.callback.clone(), Some(task_locals.clone_ref(py)));
140-
}
141-
(self.callback.clone(), None)
142-
});
143-
144-
if let Some(task_locals) = task_locals {
145-
tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move {
146-
let future = Python::with_gil(|py| {
147-
let awaitable = callback.call1(
148-
py,
149-
(
150-
lister_notification.channel,
151-
lister_notification.payload,
152-
lister_notification.process_id,
153-
connection,
154-
)
155-
).unwrap();
156-
pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)).unwrap()
157-
});
158-
future.await.unwrap();
159-
})).await?;
160-
};
161-
162-
Ok(())
163-
}
164-
}
16517

16618
#[pyclass]
16719
pub struct Listener {

src/driver/listener/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod structs;
2+
pub mod core;

src/driver/listener/structs.rs

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
use std::collections::{hash_map::Entry, HashMap};
2+
3+
use pyo3::{pyclass, pymethods, Py, PyAny, Python};
4+
use pyo3_async_runtimes::TaskLocals;
5+
use tokio_postgres::Notification;
6+
7+
use crate::{
8+
driver::connection::Connection, exceptions::rust_errors::RustPSQLDriverPyResult, runtime::tokio_runtime
9+
};
10+
11+
12+
pub struct ChannelCallbacks(HashMap<String, Vec<ListenerCallback>>);
13+
14+
impl Default for ChannelCallbacks {
15+
fn default() -> Self {
16+
ChannelCallbacks(Default::default())
17+
}
18+
}
19+
20+
impl ChannelCallbacks {
21+
pub fn add_callback(&mut self, channel: String, callback: ListenerCallback) {
22+
match self.0.entry(channel) {
23+
Entry::Vacant(e) => {
24+
e.insert(vec![callback]);
25+
}
26+
Entry::Occupied(mut e) => {
27+
e.get_mut().push(callback);
28+
}
29+
};
30+
}
31+
32+
pub fn retrieve_channel_callbacks(&self, channel: String) -> Option<&Vec<ListenerCallback>> {
33+
self.0.get(&channel)
34+
}
35+
36+
pub fn clear_channel_callbacks(&mut self, channel: String) {
37+
self.0.remove(&channel);
38+
}
39+
40+
pub fn retrieve_all_channels(&self) -> Vec<&String> {
41+
self.0.keys().collect::<Vec<&String>>()
42+
}
43+
}
44+
45+
46+
#[derive(Clone, Debug)]
47+
pub struct ListenerNotification {
48+
pub process_id: i32,
49+
pub channel: String,
50+
pub payload: String,
51+
}
52+
53+
impl From::<Notification> for ListenerNotification {
54+
fn from(value: Notification) -> Self {
55+
ListenerNotification {
56+
process_id: value.process_id(),
57+
channel: String::from(value.channel()),
58+
payload: String::from(value.payload()),
59+
}
60+
}
61+
}
62+
63+
#[pyclass]
64+
pub struct ListenerNotificationMsg {
65+
process_id: i32,
66+
channel: String,
67+
payload: String,
68+
connection: Connection,
69+
}
70+
71+
#[pymethods]
72+
impl ListenerNotificationMsg {
73+
#[getter]
74+
fn process_id(&self) -> i32 {
75+
self.process_id
76+
}
77+
78+
#[getter]
79+
fn channel(&self) -> String {
80+
self.channel.clone()
81+
}
82+
83+
#[getter]
84+
fn payload(&self) -> String {
85+
self.payload.clone()
86+
}
87+
88+
#[getter]
89+
fn connection(&self) -> Connection {
90+
self.connection.clone()
91+
}
92+
}
93+
94+
impl ListenerNotificationMsg {
95+
pub fn new(value: ListenerNotification, conn: Connection) -> Self {
96+
ListenerNotificationMsg {
97+
process_id: value.process_id,
98+
channel: String::from(value.channel),
99+
payload: String::from(value.payload),
100+
connection: conn,
101+
}
102+
}
103+
}
104+
105+
pub struct ListenerCallback {
106+
task_locals: Option<TaskLocals>,
107+
callback: Py<PyAny>,
108+
}
109+
110+
impl ListenerCallback {
111+
pub fn new(
112+
task_locals: Option<TaskLocals>,
113+
callback: Py<PyAny>,
114+
) -> Self {
115+
ListenerCallback {
116+
task_locals,
117+
callback,
118+
}
119+
}
120+
121+
pub async fn call(
122+
&self,
123+
lister_notification: ListenerNotification,
124+
connection: Connection,
125+
) -> RustPSQLDriverPyResult<()> {
126+
let (callback, task_locals) = Python::with_gil(|py| {
127+
if let Some(task_locals) = &self.task_locals {
128+
return (self.callback.clone(), Some(task_locals.clone_ref(py)));
129+
}
130+
(self.callback.clone(), None)
131+
});
132+
133+
if let Some(task_locals) = task_locals {
134+
tokio_runtime().spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move {
135+
let future = Python::with_gil(|py| {
136+
let awaitable = callback.call1(
137+
py,
138+
(
139+
lister_notification.channel,
140+
lister_notification.payload,
141+
lister_notification.process_id,
142+
connection,
143+
)
144+
).unwrap();
145+
pyo3_async_runtimes::tokio::into_future(awaitable.into_bound(py)).unwrap()
146+
});
147+
future.await.unwrap();
148+
})).await?;
149+
};
150+
151+
Ok(())
152+
}
153+
}

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> {
2929
pymod.add_class::<driver::connection::Connection>()?;
3030
pymod.add_class::<driver::transaction::Transaction>()?;
3131
pymod.add_class::<driver::cursor::Cursor>()?;
32-
pymod.add_class::<driver::listener::Listener>()?;
33-
pymod.add_class::<driver::listener::ListenerNotificationMsg>()?;
32+
pymod.add_class::<driver::listener::core::Listener>()?;
33+
pymod.add_class::<driver::listener::structs::ListenerNotificationMsg>()?;
3434
pymod.add_class::<driver::transaction_options::IsolationLevel>()?;
3535
pymod.add_class::<driver::transaction_options::SynchronousCommit>()?;
3636
pymod.add_class::<driver::transaction_options::ReadVariant>()?;

0 commit comments

Comments
 (0)