Skip to content

Commit cebab09

Browse files
alfredotgalk888
andauthored
212/fix-infinite-hangs (#213)
Co-authored-by: alk888 <aleksei@g.com>
1 parent 20bee06 commit cebab09

File tree

5 files changed

+127
-33
lines changed

5 files changed

+127
-33
lines changed

src/client/dispatcher.rs

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use futures::Stream;
22
use rabbitmq_stream_protocol::Response;
3-
use std::sync::{atomic::AtomicU32, Arc};
3+
use std::sync::{
4+
atomic::{AtomicBool, AtomicU32, Ordering},
5+
Arc,
6+
};
47
use tracing::trace;
58

69
use dashmap::DashMap;
@@ -17,7 +20,7 @@ use super::{channel::ChannelReceiver, handler::MessageHandler};
1720
pub(crate) struct Dispatcher<T>(DispatcherState<T>);
1821

1922
pub(crate) struct DispatcherState<T> {
20-
requests: Arc<DashMap<u32, Sender<Response>>>,
23+
requests: Arc<RequestsMap>,
2124
correlation_id: Arc<AtomicU32>,
2225
handler: Arc<RwLock<Option<T>>>,
2326
}
@@ -32,13 +35,49 @@ impl<T> Clone for DispatcherState<T> {
3235
}
3336
}
3437

38+
struct RequestsMap {
39+
requests: DashMap<u32, Sender<Response>>,
40+
closed: AtomicBool,
41+
}
42+
43+
impl RequestsMap {
44+
fn new() -> RequestsMap {
45+
RequestsMap {
46+
requests: DashMap::new(),
47+
closed: AtomicBool::new(false),
48+
}
49+
}
50+
51+
fn insert(&self, correlation_id: u32, sender: Sender<Response>) -> bool {
52+
if self.closed.load(Ordering::Relaxed) {
53+
return false;
54+
}
55+
self.requests.insert(correlation_id, sender);
56+
true
57+
}
58+
59+
fn remove(&self, correlation_id: u32) -> Option<Sender<Response>> {
60+
self.requests.remove(&correlation_id).map(|r| r.1)
61+
}
62+
63+
fn close(&self) {
64+
self.closed.store(true, Ordering::Relaxed);
65+
self.requests.clear();
66+
}
67+
68+
#[cfg(test)]
69+
fn len(&self) -> usize {
70+
self.requests.len()
71+
}
72+
}
73+
3574
impl<T> Dispatcher<T>
3675
where
3776
T: MessageHandler,
3877
{
3978
pub fn new() -> Dispatcher<T> {
4079
Dispatcher(DispatcherState {
41-
requests: Arc::new(DashMap::new()),
80+
requests: Arc::new(RequestsMap::new()),
4281
correlation_id: Arc::new(AtomicU32::new(0)),
4382
handler: Arc::new(RwLock::new(None)),
4483
})
@@ -47,23 +86,25 @@ where
4786
#[cfg(test)]
4887
pub fn with_handler(handler: T) -> Dispatcher<T> {
4988
Dispatcher(DispatcherState {
50-
requests: Arc::new(DashMap::new()),
89+
requests: Arc::new(RequestsMap::new()),
5190
correlation_id: Arc::new(AtomicU32::new(0)),
5291
handler: Arc::new(RwLock::new(Some(handler))),
5392
})
5493
}
5594

56-
pub async fn response_channel(&self) -> (u32, Receiver<Response>) {
95+
pub fn response_channel(&self) -> Option<(u32, Receiver<Response>)> {
5796
let (tx, rx) = channel(1);
5897

5998
let correlation_id = self
6099
.0
61100
.correlation_id
62101
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
63102

64-
self.0.requests.insert(correlation_id, tx);
65-
66-
(correlation_id, rx)
103+
if self.0.requests.insert(correlation_id, tx) {
104+
Some((correlation_id, rx))
105+
} else {
106+
None
107+
}
67108
}
68109

69110
#[cfg(test)]
@@ -75,6 +116,7 @@ where
75116
let mut guard = self.0.handler.write().await;
76117
*guard = Some(handler);
77118
}
119+
78120
pub async fn start<R>(&self, stream: ChannelReceiver<R>)
79121
where
80122
R: Stream<Item = Result<Response, ClientError>> + Unpin + Send,
@@ -89,10 +131,10 @@ where
89131
T: MessageHandler,
90132
{
91133
pub async fn dispatch(&self, correlation_id: u32, response: Response) {
92-
let receiver = self.requests.remove(&correlation_id);
134+
let receiver = self.requests.remove(correlation_id);
93135

94136
if let Some(rcv) = receiver {
95-
let _ = rcv.1.send(response).await;
137+
let _ = rcv.send(response).await;
96138
}
97139
}
98140

@@ -103,6 +145,7 @@ where
103145
}
104146

105147
pub async fn close(self, error: Option<ClientError>) {
148+
self.requests.close();
106149
if let Some(handler) = self.handler.read().await.as_ref() {
107150
if let Some(err) = error {
108151
let _ = handler.handle_message(Some(Err(err))).await;
@@ -265,7 +308,7 @@ mod tests {
265308

266309
dispatcher.start(rx).await;
267310

268-
let (correlation_id, mut rx) = dispatcher.response_channel().await;
311+
let (correlation_id, mut rx) = dispatcher.response_channel().unwrap();
269312

270313
let req: Request = PeerPropertiesCommand::new(correlation_id, HashMap::new()).into();
271314

@@ -298,4 +341,19 @@ mod tests {
298341

299342
assert!(matches!(response, Some(..)));
300343
}
344+
345+
#[tokio::test]
346+
async fn should_reject_requests_after_closing() {
347+
let mock_source = MockIO::push();
348+
349+
let dispatcher = Dispatcher::with_handler(|_| async { Ok(()) });
350+
351+
let maybe_channel = dispatcher.response_channel();
352+
assert!(maybe_channel.is_some());
353+
354+
dispatcher.0.requests.close();
355+
356+
let maybe_channel = dispatcher.response_channel();
357+
assert!(maybe_channel.is_none());
358+
}
301359
}

src/client/mod.rs

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::convert::TryFrom;
22

3+
use std::ops::DerefMut;
34
use std::{
45
collections::HashMap,
56
io,
@@ -21,8 +22,8 @@ use std::{fs::File, io::BufReader, path::Path};
2122
use tokio::io::AsyncRead;
2223
use tokio::io::AsyncWrite;
2324
use tokio::io::ReadBuf;
25+
use tokio::sync::RwLock;
2426
use tokio::{net::TcpStream, sync::Notify};
25-
use tokio::{sync::RwLock, task::JoinHandle};
2627
use tokio_rustls::client::TlsStream;
2728
use tokio_rustls::rustls::ClientConfig;
2829
use tokio_rustls::{rustls, TlsConnector};
@@ -79,6 +80,7 @@ mod message;
7980
mod metadata;
8081
mod metrics;
8182
mod options;
83+
mod task;
8284

8385
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-stream")))]
8486
#[pin_project(project = StreamProj)]
@@ -138,7 +140,7 @@ pub struct ClientState {
138140
heartbeat: u32,
139141
max_frame_size: u32,
140142
last_heatbeat: Instant,
141-
heartbeat_task: Option<JoinHandle<()>>,
143+
heartbeat_task: Option<task::TaskHandle>,
142144
}
143145

144146
#[async_trait::async_trait]
@@ -249,9 +251,7 @@ impl Client {
249251

250252
let mut state = self.state.write().await;
251253

252-
if let Some(heartbeat_task) = state.heartbeat_task.take() {
253-
heartbeat_task.abort();
254-
}
254+
state.heartbeat_task.take();
255255

256256
drop(state);
257257
self.channel.close().await
@@ -476,6 +476,9 @@ impl Client {
476476
})
477477
.await?;
478478

479+
// Start heartbeat task after connection is established
480+
self.start_hearbeat_task(self.state.write().await.deref_mut());
481+
479482
Ok(())
480483
}
481484

@@ -545,13 +548,15 @@ impl Client {
545548
T: FromResponse,
546549
M: FnOnce(u32) -> R,
547550
{
548-
let (correlation_id, mut receiver) = self.dispatcher.response_channel().await;
551+
let Some((correlation_id, mut receiver)) = self.dispatcher.response_channel() else {
552+
return Err(ClientError::ConnectionClosed);
553+
};
549554

550555
self.channel
551556
.send(msg_factory(correlation_id).into())
552557
.await?;
553558

554-
let response = receiver.recv().await.expect("It should contain a response");
559+
let response = receiver.recv().await.ok_or(ClientError::ConnectionClosed)?;
555560

556561
self.handle_response::<T>(response).await
557562
}
@@ -609,21 +614,8 @@ impl Client {
609614
heart_beat
610615
);
611616

612-
if let Some(task) = state.heartbeat_task.take() {
613-
task.abort();
614-
}
615-
616-
if heart_beat != 0 {
617-
let heartbeat_interval = (heart_beat / 2).max(1);
618-
let channel = self.channel.clone();
619-
let heartbeat_task = tokio::spawn(async move {
620-
loop {
621-
trace!("Sending heartbeat");
622-
let _ = channel.send(HeartBeatCommand::default().into()).await;
623-
tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;
624-
}
625-
});
626-
state.heartbeat_task = Some(heartbeat_task);
617+
if state.heartbeat_task.take().is_some() {
618+
self.start_hearbeat_task(&mut state);
627619
}
628620

629621
drop(state);
@@ -636,6 +628,23 @@ impl Client {
636628
self.tune_notifier.notify_one();
637629
}
638630

631+
fn start_hearbeat_task(&self, state: &mut ClientState) {
632+
if state.heartbeat == 0 {
633+
return;
634+
}
635+
let heartbeat_interval = (state.heartbeat / 2).max(1);
636+
let channel = self.channel.clone();
637+
let heartbeat_task = tokio::spawn(async move {
638+
loop {
639+
trace!("Sending heartbeat");
640+
let _ = channel.send(HeartBeatCommand::default().into()).await;
641+
tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;
642+
}
643+
})
644+
.into();
645+
state.heartbeat_task = Some(heartbeat_task);
646+
}
647+
639648
async fn handle_heart_beat_command(&self) {
640649
trace!("Received heartbeat");
641650
let mut state = self.state.write().await;

src/client/task.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
pub struct TaskHandle {
2+
task: tokio::task::JoinHandle<()>,
3+
}
4+
5+
impl From<tokio::task::JoinHandle<()>> for TaskHandle {
6+
fn from(task: tokio::task::JoinHandle<()>) -> Self {
7+
TaskHandle { task }
8+
}
9+
}
10+
11+
impl Drop for TaskHandle {
12+
fn drop(&mut self) {
13+
self.task.abort();
14+
}
15+
}

src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ pub enum ClientError {
1717
GenericError(#[from] Box<dyn std::error::Error + Send + Sync>),
1818
#[error("Client already closed")]
1919
AlreadyClosed,
20+
#[error("Connection closed")]
21+
ConnectionClosed,
2022
#[error(transparent)]
2123
Tls(#[from] tokio_rustls::rustls::Error),
2224
#[error("Request error: {0:?}")]

tests/integration/client_test.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::collections::HashMap;
22

33
use fake::{Fake, Faker};
4+
use rabbitmq_stream_protocol::commands::close::CloseRequest;
45
use tokio::sync::mpsc::channel;
56

67
use rabbitmq_stream_client::error::ClientError;
@@ -368,3 +369,12 @@ async fn client_publish() {
368369
delivery.messages.get(0).unwrap().data()
369370
);
370371
}
372+
373+
#[cfg(test)]
374+
#[tokio::test(flavor = "multi_thread")]
375+
async fn client_handle_unexpected_connection_interruption() {
376+
let mut options = ClientOptions::default();
377+
options.set_port(5672);
378+
let res = Client::connect(options).await;
379+
assert!(matches!(res, Err(ClientError::ConnectionClosed)));
380+
}

0 commit comments

Comments
 (0)