@@ -25,6 +25,7 @@ use rabbitmq_stream_protocol::{
25
25
delete:: Delete ,
26
26
delete_publisher:: DeletePublisherCommand ,
27
27
generic:: GenericResponse ,
28
+ heart_beat:: HeartBeatCommand ,
28
29
metadata:: MetadataCommand ,
29
30
open:: { OpenCommand , OpenResponse } ,
30
31
peer_properties:: { PeerPropertiesCommand , PeerPropertiesResponse } ,
@@ -41,6 +42,7 @@ use rabbitmq_stream_protocol::{
41
42
types:: PublishedMessage ,
42
43
FromResponse , Request , Response , ResponseCode , ResponseKind ,
43
44
} ;
45
+ use tokio_native_tls:: TlsStream ;
44
46
use tracing:: trace;
45
47
46
48
pub use self :: handler:: { MessageHandler , MessageResult } ;
@@ -58,14 +60,14 @@ use std::{
58
60
pin:: Pin ,
59
61
sync:: { atomic:: AtomicU64 , Arc } ,
60
62
task:: { Context , Poll } ,
63
+ time:: { Duration , Instant } ,
61
64
} ;
62
65
use std:: { future:: Future , sync:: atomic:: Ordering } ;
63
66
use tokio:: io:: AsyncRead ;
64
67
use tokio:: io:: AsyncWrite ;
65
68
use tokio:: io:: ReadBuf ;
66
- use tokio:: sync:: RwLock ;
67
69
use tokio:: { net:: TcpStream , sync:: Notify } ;
68
- use tokio_native_tls :: TlsStream ;
70
+ use tokio :: { sync :: RwLock , task :: JoinHandle } ;
69
71
use tokio_util:: codec:: Framed ;
70
72
71
73
#[ cfg_attr( docsrs, doc( cfg( feature = "tokio-stream" ) ) ) ]
@@ -125,6 +127,8 @@ pub struct ClientState {
125
127
handler : Option < Arc < dyn MessageHandler > > ,
126
128
heartbeat : u32 ,
127
129
max_frame_size : u32 ,
130
+ last_heatbeat : Instant ,
131
+ heartbeat_task : Option < JoinHandle < ( ) > > ,
128
132
}
129
133
130
134
#[ async_trait:: async_trait]
@@ -133,6 +137,7 @@ impl MessageHandler for Client {
133
137
match & item {
134
138
Some ( Ok ( response) ) => match response. kind_ref ( ) {
135
139
ResponseKind :: Tunes ( tune) => self . handle_tune_command ( tune) . await ,
140
+ ResponseKind :: Heartbeat ( _) => self . handle_heart_beat_command ( ) . await ,
136
141
_ => {
137
142
if let Some ( handler) = self . state . read ( ) . await . handler . as_ref ( ) {
138
143
let handler = handler. clone ( ) ;
@@ -188,6 +193,8 @@ impl Client {
188
193
handler : None ,
189
194
heartbeat : broker. heartbeat ,
190
195
max_frame_size : broker. max_frame_size ,
196
+ last_heatbeat : Instant :: now ( ) ,
197
+ heartbeat_task : None ,
191
198
} ;
192
199
let mut client = Client {
193
200
dispatcher,
@@ -228,6 +235,14 @@ impl Client {
228
235
CloseRequest :: new ( correlation_id, ResponseCode :: Ok , "Ok" . to_owned ( ) )
229
236
} )
230
237
. await ?;
238
+
239
+ let mut state = self . state . write ( ) . await ;
240
+
241
+ if let Some ( heartbeat_task) = state. heartbeat_task . take ( ) {
242
+ heartbeat_task. abort ( ) ;
243
+ }
244
+
245
+ drop ( state) ;
231
246
self . channel . close ( ) . await
232
247
}
233
248
pub async fn subscribe (
@@ -451,10 +466,10 @@ impl Client {
451
466
Ok ( ( ) )
452
467
}
453
468
454
- fn max_value ( & self , client : u32 , server : u32 ) -> u32 {
469
+ fn negotiate_value ( & self , client : u32 , server : u32 ) -> u32 {
455
470
match ( client, server) {
456
471
( client, server) if client == 0 || server == 0 => client. max ( server) ,
457
- ( client, server) => client. max ( server) ,
472
+ ( client, server) => client. min ( server) ,
458
473
}
459
474
}
460
475
@@ -543,11 +558,35 @@ impl Client {
543
558
544
559
async fn handle_tune_command ( & self , tunes : & TunesCommand ) {
545
560
let mut state = self . state . write ( ) . await ;
546
- state. heartbeat = self . max_value ( self . opts . heartbeat , tunes. heartbeat ) ;
547
- state. max_frame_size = self . max_value ( self . opts . max_frame_size , tunes. max_frame_size ) ;
561
+ state. heartbeat = self . negotiate_value ( self . opts . heartbeat , tunes. heartbeat ) ;
562
+ state. max_frame_size = self . negotiate_value ( self . opts . max_frame_size , tunes. max_frame_size ) ;
548
563
549
564
let heart_beat = state. heartbeat ;
550
565
let max_frame_size = state. max_frame_size ;
566
+
567
+ trace ! (
568
+ "Handling tune with frame size {} and heartbeat {}" ,
569
+ max_frame_size,
570
+ heart_beat
571
+ ) ;
572
+
573
+ if let Some ( task) = state. heartbeat_task . take ( ) {
574
+ task. abort ( ) ;
575
+ }
576
+
577
+ if heart_beat != 0 {
578
+ let heartbeat_interval = ( heart_beat / 2 ) . max ( 1 ) ;
579
+ let channel = self . channel . clone ( ) ;
580
+ let heartbeat_task = tokio:: spawn ( async move {
581
+ loop {
582
+ trace ! ( "Sending heartbeat" ) ;
583
+ let _ = channel. send ( HeartBeatCommand :: default ( ) . into ( ) ) . await ;
584
+ tokio:: time:: sleep ( Duration :: from_secs ( heartbeat_interval. into ( ) ) ) . await ;
585
+ }
586
+ } ) ;
587
+ state. heartbeat_task = Some ( heartbeat_task) ;
588
+ }
589
+
551
590
drop ( state) ;
552
591
553
592
let _ = self
@@ -557,4 +596,10 @@ impl Client {
557
596
558
597
self . tune_notifier . notify_one ( ) ;
559
598
}
599
+
600
+ async fn handle_heart_beat_command ( & self ) {
601
+ trace ! ( "Received heartbeat" ) ;
602
+ let mut state = self . state . write ( ) . await ;
603
+ state. last_heatbeat = Instant :: now ( ) ;
604
+ }
560
605
}
0 commit comments