1
1
use futures:: Stream ;
2
2
use rabbitmq_stream_protocol:: Response ;
3
- use std:: sync:: { atomic:: AtomicU32 , Arc } ;
3
+ use std:: sync:: {
4
+ atomic:: { AtomicBool , AtomicU32 , Ordering } ,
5
+ Arc ,
6
+ } ;
4
7
use tracing:: trace;
5
8
6
9
use dashmap:: DashMap ;
@@ -17,7 +20,7 @@ use super::{channel::ChannelReceiver, handler::MessageHandler};
17
20
pub ( crate ) struct Dispatcher < T > ( DispatcherState < T > ) ;
18
21
19
22
pub ( crate ) struct DispatcherState < T > {
20
- requests : Arc < DashMap < u32 , Sender < Response > > > ,
23
+ requests : Arc < RequestsMap > ,
21
24
correlation_id : Arc < AtomicU32 > ,
22
25
handler : Arc < RwLock < Option < T > > > ,
23
26
}
@@ -32,13 +35,49 @@ impl<T> Clone for DispatcherState<T> {
32
35
}
33
36
}
34
37
38
+ struct RequestsMap {
39
+ requests : DashMap < u32 , Sender < Response > > ,
40
+ closed : AtomicBool ,
41
+ }
42
+
43
+ impl RequestsMap {
44
+ fn new ( ) -> RequestsMap {
45
+ RequestsMap {
46
+ requests : DashMap :: new ( ) ,
47
+ closed : AtomicBool :: new ( false ) ,
48
+ }
49
+ }
50
+
51
+ fn insert ( & self , correlation_id : u32 , sender : Sender < Response > ) -> bool {
52
+ if self . closed . load ( Ordering :: Relaxed ) {
53
+ return false ;
54
+ }
55
+ self . requests . insert ( correlation_id, sender) ;
56
+ true
57
+ }
58
+
59
+ fn remove ( & self , correlation_id : u32 ) -> Option < Sender < Response > > {
60
+ self . requests . remove ( & correlation_id) . map ( |r| r. 1 )
61
+ }
62
+
63
+ fn close ( & self ) {
64
+ self . closed . store ( true , Ordering :: Relaxed ) ;
65
+ self . requests . clear ( ) ;
66
+ }
67
+
68
+ #[ cfg( test) ]
69
+ fn len ( & self ) -> usize {
70
+ self . requests . len ( )
71
+ }
72
+ }
73
+
35
74
impl < T > Dispatcher < T >
36
75
where
37
76
T : MessageHandler ,
38
77
{
39
78
pub fn new ( ) -> Dispatcher < T > {
40
79
Dispatcher ( DispatcherState {
41
- requests : Arc :: new ( DashMap :: new ( ) ) ,
80
+ requests : Arc :: new ( RequestsMap :: new ( ) ) ,
42
81
correlation_id : Arc :: new ( AtomicU32 :: new ( 0 ) ) ,
43
82
handler : Arc :: new ( RwLock :: new ( None ) ) ,
44
83
} )
@@ -47,23 +86,25 @@ where
47
86
#[ cfg( test) ]
48
87
pub fn with_handler ( handler : T ) -> Dispatcher < T > {
49
88
Dispatcher ( DispatcherState {
50
- requests : Arc :: new ( DashMap :: new ( ) ) ,
89
+ requests : Arc :: new ( RequestsMap :: new ( ) ) ,
51
90
correlation_id : Arc :: new ( AtomicU32 :: new ( 0 ) ) ,
52
91
handler : Arc :: new ( RwLock :: new ( Some ( handler) ) ) ,
53
92
} )
54
93
}
55
94
56
- pub async fn response_channel ( & self ) -> ( u32 , Receiver < Response > ) {
95
+ pub fn response_channel ( & self ) -> Option < ( u32 , Receiver < Response > ) > {
57
96
let ( tx, rx) = channel ( 1 ) ;
58
97
59
98
let correlation_id = self
60
99
. 0
61
100
. correlation_id
62
101
. fetch_add ( 1 , std:: sync:: atomic:: Ordering :: Relaxed ) ;
63
102
64
- self . 0 . requests . insert ( correlation_id, tx) ;
65
-
66
- ( correlation_id, rx)
103
+ if self . 0 . requests . insert ( correlation_id, tx) {
104
+ Some ( ( correlation_id, rx) )
105
+ } else {
106
+ None
107
+ }
67
108
}
68
109
69
110
#[ cfg( test) ]
75
116
let mut guard = self . 0 . handler . write ( ) . await ;
76
117
* guard = Some ( handler) ;
77
118
}
119
+
78
120
pub async fn start < R > ( & self , stream : ChannelReceiver < R > )
79
121
where
80
122
R : Stream < Item = Result < Response , ClientError > > + Unpin + Send ,
@@ -89,10 +131,10 @@ where
89
131
T : MessageHandler ,
90
132
{
91
133
pub async fn dispatch ( & self , correlation_id : u32 , response : Response ) {
92
- let receiver = self . requests . remove ( & correlation_id) ;
134
+ let receiver = self . requests . remove ( correlation_id) ;
93
135
94
136
if let Some ( rcv) = receiver {
95
- let _ = rcv. 1 . send ( response) . await ;
137
+ let _ = rcv. send ( response) . await ;
96
138
}
97
139
}
98
140
@@ -103,6 +145,7 @@ where
103
145
}
104
146
105
147
pub async fn close ( self , error : Option < ClientError > ) {
148
+ self . requests . close ( ) ;
106
149
if let Some ( handler) = self . handler . read ( ) . await . as_ref ( ) {
107
150
if let Some ( err) = error {
108
151
let _ = handler. handle_message ( Some ( Err ( err) ) ) . await ;
@@ -265,7 +308,7 @@ mod tests {
265
308
266
309
dispatcher. start ( rx) . await ;
267
310
268
- let ( correlation_id, mut rx) = dispatcher. response_channel ( ) . await ;
311
+ let ( correlation_id, mut rx) = dispatcher. response_channel ( ) . unwrap ( ) ;
269
312
270
313
let req: Request = PeerPropertiesCommand :: new ( correlation_id, HashMap :: new ( ) ) . into ( ) ;
271
314
@@ -298,4 +341,19 @@ mod tests {
298
341
299
342
assert ! ( matches!( response, Some ( ..) ) ) ;
300
343
}
344
+
345
+ #[ tokio:: test]
346
+ async fn should_reject_requests_after_closing ( ) {
347
+ let mock_source = MockIO :: push ( ) ;
348
+
349
+ let dispatcher = Dispatcher :: with_handler ( |_| async { Ok ( ( ) ) } ) ;
350
+
351
+ let maybe_channel = dispatcher. response_channel ( ) ;
352
+ assert ! ( maybe_channel. is_some( ) ) ;
353
+
354
+ dispatcher. 0 . requests . close ( ) ;
355
+
356
+ let maybe_channel = dispatcher. response_channel ( ) ;
357
+ assert ! ( maybe_channel. is_none( ) ) ;
358
+ }
301
359
}
0 commit comments