@@ -9,34 +9,36 @@ use pegboard::tunnel::id::{self as tunnel_id, RequestId};
99use rand:: Rng ;
1010use rivet_error:: * ;
1111use rivet_guard_core:: {
12+ WebSocketHandle ,
1213 custom_serve:: { CustomServeTrait , HibernationResult } ,
13- errors:: {
14- ServiceUnavailable , WebSocketServiceHibernate , WebSocketServiceTimeout ,
15- WebSocketServiceUnavailable ,
16- } ,
17- proxy_service:: { is_ws_hibernate, ResponseBody } ,
14+ errors:: { ServiceUnavailable , WebSocketServiceUnavailable } ,
15+ proxy_service:: { ResponseBody , is_ws_hibernate} ,
1816 request_context:: RequestContext ,
1917 websocket_handle:: WebSocketReceiver ,
20- WebSocketHandle ,
2118} ;
2219use rivet_runner_protocol as protocol;
2320use rivet_util:: serde:: HashableMap ;
2421use std:: { sync:: Arc , time:: Duration } ;
2522use tokio:: {
26- sync:: { watch , Mutex } ,
23+ sync:: { Mutex , watch } ,
2724 task:: JoinHandle ,
2825} ;
2926use tokio_tungstenite:: tungstenite:: {
30- protocol:: frame:: { coding:: CloseCode , CloseFrame } ,
3127 Message ,
28+ protocol:: frame:: { CloseFrame , coding:: CloseCode } ,
3229} ;
3330
3431use crate :: shared_state:: { InFlightRequestHandle , SharedState } ;
3532
33+ mod metrics;
34+ mod ping_task;
3635pub mod shared_state;
36+ mod tunnel_to_ws_task;
37+ mod ws_to_tunnel_task;
3738
3839const WEBSOCKET_OPEN_TIMEOUT : Duration = Duration :: from_secs ( 15 ) ;
3940const TUNNEL_ACK_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
41+ const UPDATE_PING_INTERVAL : Duration = Duration :: from_secs ( 3 ) ;
4042
4143#[ derive( RivetError , Serialize , Deserialize ) ]
4244#[ error(
@@ -391,143 +393,47 @@ impl CustomServeTrait for PegboardGateway {
391393
392394 let ws_rx = client_ws. recv ( ) ;
393395
394- let ( tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
395- let ( ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
396-
397- // Spawn task to forward messages from tunnel to ws
398- let shared_state = self . shared_state . clone ( ) ;
399- let tunnel_to_ws = tokio:: spawn (
400- async move {
401- loop {
402- tokio:: select! {
403- res = msg_rx. recv( ) => {
404- if let Some ( msg) = res {
405- match msg {
406- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketMessage ( ws_msg) => {
407- let msg = if ws_msg. binary {
408- Message :: Binary ( ws_msg. data. into( ) )
409- } else {
410- Message :: Text (
411- String :: from_utf8_lossy( & ws_msg. data) . into_owned( ) . into( ) ,
412- )
413- } ;
414- client_ws. send( msg) . await ?;
415- }
416- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketMessageAck ( ack) => {
417- tracing:: debug!(
418- request_id=?tunnel_id:: request_id_to_string( & request_id) ,
419- ack_index=?ack. index,
420- "received WebSocketMessageAck from runner"
421- ) ;
422- shared_state
423- . ack_pending_websocket_messages( request_id, ack. index)
424- . await ?;
425- }
426- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketClose ( close) => {
427- tracing:: debug!( ?close, "server closed websocket" ) ;
428-
429- if can_hibernate && close. hibernate {
430- return Err ( WebSocketServiceHibernate . build( ) ) ;
431- } else {
432- // Successful closure
433- return Ok ( LifecycleResult :: ServerClose ( close) ) ;
434- }
435- }
436- _ => { }
437- }
438- } else {
439- tracing:: debug!( "tunnel sub closed" ) ;
440- return Err ( WebSocketServiceHibernate . build( ) ) ;
441- }
442- }
443- _ = stopped_sub. next( ) => {
444- tracing:: debug!( "actor stopped during websocket handler loop" ) ;
445-
446- if can_hibernate {
447- return Err ( WebSocketServiceHibernate . build( ) ) ;
448- } else {
449- return Err ( WebSocketServiceUnavailable . build( ) ) ;
450- }
451- }
452- _ = drop_rx. changed( ) => {
453- tracing:: warn!( "websocket message timeout" ) ;
454- return Err ( WebSocketServiceTimeout . build( ) ) ;
455- }
456- _ = tunnel_to_ws_abort_rx. changed( ) => {
457- tracing:: debug!( "task aborted" ) ;
458- return Ok ( LifecycleResult :: Aborted ) ;
459- }
460- }
461- }
462- }
463- . instrument ( tracing:: info_span!( "tunnel_to_ws_task" ) ) ,
464- ) ;
465-
466- // Spawn task to forward messages from ws to tunnel
467- let shared_state_clone = self . shared_state . clone ( ) ;
468- let ws_to_tunnel = tokio:: spawn (
469- async move {
470- let mut ws_rx = ws_rx. lock ( ) . await ;
471-
472- loop {
473- tokio:: select! {
474- res = ws_rx. try_next( ) => {
475- if let Some ( msg) = res? {
476- match msg {
477- Message :: Binary ( data) => {
478- let ws_message =
479- protocol:: ToClientTunnelMessageKind :: ToClientWebSocketMessage (
480- protocol:: ToClientWebSocketMessage {
481- data: data. into( ) ,
482- binary: true ,
483- } ,
484- ) ;
485- shared_state_clone
486- . send_message( request_id, ws_message)
487- . await ?;
488- }
489- Message :: Text ( text) => {
490- let ws_message =
491- protocol:: ToClientTunnelMessageKind :: ToClientWebSocketMessage (
492- protocol:: ToClientWebSocketMessage {
493- data: text. as_bytes( ) . to_vec( ) ,
494- binary: false ,
495- } ,
496- ) ;
497- shared_state_clone
498- . send_message( request_id, ws_message)
499- . await ?;
500- }
501- Message :: Close ( close) => {
502- return Ok ( LifecycleResult :: ClientClose ( close) ) ;
503- }
504- _ => { }
505- }
506- } else {
507- tracing:: debug!( "websocket stream closed" ) ;
508- return Ok ( LifecycleResult :: ClientClose ( None ) ) ;
509- }
510- }
511- _ = ws_to_tunnel_abort_rx. changed( ) => {
512- tracing:: debug!( "task aborted" ) ;
513- return Ok ( LifecycleResult :: Aborted ) ;
514- }
515- } ;
516- }
517- }
518- . instrument ( tracing:: info_span!( "ws_to_tunnel_task" ) ) ,
519- ) ;
396+ let ( tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
397+ let ( ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
398+ let ( ping_abort_tx, ping_abort_rx) = watch:: channel ( ( ) ) ;
399+
400+ let tunnel_to_ws = tokio:: spawn ( tunnel_to_ws_task:: task (
401+ self . shared_state . clone ( ) ,
402+ client_ws,
403+ request_id,
404+ stopped_sub,
405+ msg_rx,
406+ drop_rx,
407+ can_hibernate,
408+ tunnel_to_ws_abort_rx,
409+ ) ) ;
410+ let ws_to_tunnel = tokio:: spawn ( ws_to_tunnel_task:: task (
411+ self . shared_state . clone ( ) ,
412+ request_id,
413+ ws_rx,
414+ ws_to_tunnel_abort_rx,
415+ ) ) ;
416+ let ping = tokio:: spawn ( ping_task:: task (
417+ self . shared_state . clone ( ) ,
418+ request_id,
419+ ping_abort_rx,
420+ ) ) ;
421+
422+ let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx. clone ( ) ;
423+ let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx. clone ( ) ;
424+ let ping_abort_tx2 = ping_abort_tx. clone ( ) ;
520425
521426 // Wait for both tasks to complete
522- let ( tunnel_to_ws_res, ws_to_tunnel_res) = tokio:: join!(
427+ let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res ) = tokio:: join!(
523428 async {
524429 let res = tunnel_to_ws. await ?;
525430
526431 // Abort other if not aborted
527432 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
528433 tracing:: debug!( ?res, "tunnel to ws task completed, aborting counterpart" ) ;
529434
530- drop( ws_to_tunnel_abort_tx) ;
435+ let _ = ping_abort_tx. send( ( ) ) ;
436+ let _ = ws_to_tunnel_abort_tx. send( ( ) ) ;
531437 } else {
532438 tracing:: debug!( ?res, "tunnel to ws task completed" ) ;
533439 }
@@ -541,25 +447,42 @@ impl CustomServeTrait for PegboardGateway {
541447 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
542448 tracing:: debug!( ?res, "ws to tunnel task completed, aborting counterpart" ) ;
543449
544- drop( tunnel_to_ws_abort_tx) ;
450+ let _ = ping_abort_tx2. send( ( ) ) ;
451+ let _ = tunnel_to_ws_abort_tx. send( ( ) ) ;
545452 } else {
546453 tracing:: debug!( ?res, "ws to tunnel task completed" ) ;
547454 }
548455
549456 res
550- }
457+ } ,
458+ async {
459+ let res = ping. await ?;
460+
461+ // Abort others if not aborted
462+ if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
463+ tracing:: debug!( ?res, "ping task completed, aborting others" ) ;
464+
465+ let _ = ws_to_tunnel_abort_tx2. send( ( ) ) ;
466+ let _ = tunnel_to_ws_abort_tx2. send( ( ) ) ;
467+ } else {
468+ tracing:: debug!( ?res, "ping task completed" ) ;
469+ }
470+
471+ res
472+ } ,
551473 ) ;
552474
553- // Determine single result from both tasks
554- let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res) {
475+ // Determine single result from all tasks
476+ let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res ) {
555477 // Prefer error
556- ( Err ( err) , _) => Err ( err) ,
557- ( _, Err ( err) ) => Err ( err) ,
478+ ( Err ( err) , _, _) => Err ( err) ,
479+ ( _, Err ( err) , _) => Err ( err) ,
480+ ( _, _, Err ( err) ) => Err ( err) ,
558481 // Prefer non aborted result if both succeed
559- ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) ) => Ok ( res) ,
560- ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) ) => Ok ( res) ,
561- // Prefer tunnel to ws if both succeed (unlikely case)
562- ( res, _) => res,
482+ ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) , _ ) => Ok ( res) ,
483+ ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) , _ ) => Ok ( res) ,
484+ // Unlikely case
485+ ( res, _, _ ) => res,
563486 } ;
564487
565488 // Send close frame to runner if not hibernating
0 commit comments