@@ -11,7 +11,7 @@ use rivet_error::*;
1111use rivet_guard_core:: {
1212 custom_serve:: { CustomServeTrait , HibernationResult } ,
1313 errors:: {
14- ServiceUnavailable , WebSocketServiceHibernate , WebSocketServiceTimeout ,
14+ ServiceUnavailable ,
1515 WebSocketServiceUnavailable ,
1616 } ,
1717 proxy_service:: { is_ws_hibernate, ResponseBody } ,
@@ -33,10 +33,15 @@ use tokio_tungstenite::tungstenite::{
3333
3434use crate :: shared_state:: { InFlightRequestHandle , SharedState } ;
3535
36+ mod metrics;
37+ mod ping_task;
3638pub mod shared_state;
39+ mod tunnel_to_ws_task;
40+ mod ws_to_tunnel_task;
3741
// 15-second duration. NOTE(review): presumably bounds how long the gateway waits for the websocket open handshake; usage is outside this view, confirm.
3842const WEBSOCKET_OPEN_TIMEOUT : Duration = Duration :: from_secs ( 15 ) ;
// 5-second duration. NOTE(review): presumably the deadline for a tunnel message acknowledgement; usage is outside this view, confirm.
3943const TUNNEL_ACK_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
// 3-second duration. NOTE(review): presumably the period of the ping task spawned alongside the two forwarding tasks; usage is outside this view, confirm.
44+ const UPDATE_PING_INTERVAL : Duration = Duration :: from_secs ( 3 ) ;
4045
4146#[ derive( RivetError , Serialize , Deserialize ) ]
4247#[ error(
@@ -391,143 +396,47 @@ impl CustomServeTrait for PegboardGateway {
391396
392397 let ws_rx = client_ws. recv ( ) ;
393398
394- let ( tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
395- let ( ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
396-
397- // Spawn task to forward messages from tunnel to ws
398- let shared_state = self . shared_state . clone ( ) ;
399- let tunnel_to_ws = tokio:: spawn (
400- async move {
401- loop {
402- tokio:: select! {
403- res = msg_rx. recv( ) => {
404- if let Some ( msg) = res {
405- match msg {
406- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketMessage ( ws_msg) => {
407- let msg = if ws_msg. binary {
408- Message :: Binary ( ws_msg. data. into( ) )
409- } else {
410- Message :: Text (
411- String :: from_utf8_lossy( & ws_msg. data) . into_owned( ) . into( ) ,
412- )
413- } ;
414- client_ws. send( msg) . await ?;
415- }
416- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketMessageAck ( ack) => {
417- tracing:: debug!(
418- request_id=?tunnel_id:: request_id_to_string( & request_id) ,
419- ack_index=?ack. index,
420- "received WebSocketMessageAck from runner"
421- ) ;
422- shared_state
423- . ack_pending_websocket_messages( request_id, ack. index)
424- . await ?;
425- }
426- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketClose ( close) => {
427- tracing:: debug!( ?close, "server closed websocket" ) ;
428-
429- if can_hibernate && close. hibernate {
430- return Err ( WebSocketServiceHibernate . build( ) ) ;
431- } else {
432- // Successful closure
433- return Ok ( LifecycleResult :: ServerClose ( close) ) ;
434- }
435- }
436- _ => { }
437- }
438- } else {
439- tracing:: debug!( "tunnel sub closed" ) ;
440- return Err ( WebSocketServiceHibernate . build( ) ) ;
441- }
442- }
443- _ = stopped_sub. next( ) => {
444- tracing:: debug!( "actor stopped during websocket handler loop" ) ;
445-
446- if can_hibernate {
447- return Err ( WebSocketServiceHibernate . build( ) ) ;
448- } else {
449- return Err ( WebSocketServiceUnavailable . build( ) ) ;
450- }
451- }
452- _ = drop_rx. changed( ) => {
453- tracing:: warn!( "websocket message timeout" ) ;
454- return Err ( WebSocketServiceTimeout . build( ) ) ;
455- }
456- _ = tunnel_to_ws_abort_rx. changed( ) => {
457- tracing:: debug!( "task aborted" ) ;
458- return Ok ( LifecycleResult :: Aborted ) ;
459- }
460- }
461- }
462- }
463- . instrument ( tracing:: info_span!( "tunnel_to_ws_task" ) ) ,
464- ) ;
465-
466- // Spawn task to forward messages from ws to tunnel
467- let shared_state_clone = self . shared_state . clone ( ) ;
468- let ws_to_tunnel = tokio:: spawn (
469- async move {
470- let mut ws_rx = ws_rx. lock ( ) . await ;
471-
472- loop {
473- tokio:: select! {
474- res = ws_rx. try_next( ) => {
475- if let Some ( msg) = res? {
476- match msg {
477- Message :: Binary ( data) => {
478- let ws_message =
479- protocol:: ToClientTunnelMessageKind :: ToClientWebSocketMessage (
480- protocol:: ToClientWebSocketMessage {
481- data: data. into( ) ,
482- binary: true ,
483- } ,
484- ) ;
485- shared_state_clone
486- . send_message( request_id, ws_message)
487- . await ?;
488- }
489- Message :: Text ( text) => {
490- let ws_message =
491- protocol:: ToClientTunnelMessageKind :: ToClientWebSocketMessage (
492- protocol:: ToClientWebSocketMessage {
493- data: text. as_bytes( ) . to_vec( ) ,
494- binary: false ,
495- } ,
496- ) ;
497- shared_state_clone
498- . send_message( request_id, ws_message)
499- . await ?;
500- }
501- Message :: Close ( close) => {
502- return Ok ( LifecycleResult :: ClientClose ( close) ) ;
503- }
504- _ => { }
505- }
506- } else {
507- tracing:: debug!( "websocket stream closed" ) ;
508- return Ok ( LifecycleResult :: ClientClose ( None ) ) ;
509- }
510- }
511- _ = ws_to_tunnel_abort_rx. changed( ) => {
512- tracing:: debug!( "task aborted" ) ;
513- return Ok ( LifecycleResult :: Aborted ) ;
514- }
515- } ;
516- }
517- }
518- . instrument ( tracing:: info_span!( "ws_to_tunnel_task" ) ) ,
519- ) ;
399+ let ( tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
400+ let ( ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
401+ let ( ping_abort_tx, ping_abort_rx) = watch:: channel ( ( ) ) ;
402+
403+ let tunnel_to_ws = tokio:: spawn ( tunnel_to_ws_task:: task (
404+ self . shared_state . clone ( ) ,
405+ client_ws,
406+ request_id,
407+ stopped_sub,
408+ msg_rx,
409+ drop_rx,
410+ can_hibernate,
411+ tunnel_to_ws_abort_rx,
412+ ) ) ;
413+ let ws_to_tunnel = tokio:: spawn ( ws_to_tunnel_task:: task (
414+ self . shared_state . clone ( ) ,
415+ request_id,
416+ ws_rx,
417+ ws_to_tunnel_abort_rx,
418+ ) ) ;
419+ let ping = tokio:: spawn ( ping_task:: task (
420+ self . shared_state . clone ( ) ,
421+ request_id,
422+ ping_abort_rx,
423+ ) ) ;
424+
425+ let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx. clone ( ) ;
426+ let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx. clone ( ) ;
427+ let ping_abort_tx2 = ping_abort_tx. clone ( ) ;
520428
521429 // Wait for all three tasks to complete
522- let ( tunnel_to_ws_res, ws_to_tunnel_res) = tokio:: join!(
430+ let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res ) = tokio:: join!(
523431 async {
524432 let res = tunnel_to_ws. await ?;
525433
526434 // Abort others if not aborted
527435 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
528436 tracing:: debug!( ?res, "tunnel to ws task completed, aborting counterpart" ) ;
529437
530- drop( ws_to_tunnel_abort_tx) ;
438+ let _ = ping_abort_tx. send( ( ) ) ;
439+ let _ = ws_to_tunnel_abort_tx. send( ( ) ) ;
531440 } else {
532441 tracing:: debug!( ?res, "tunnel to ws task completed" ) ;
533442 }
@@ -541,25 +450,42 @@ impl CustomServeTrait for PegboardGateway {
541450 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
542451 tracing:: debug!( ?res, "ws to tunnel task completed, aborting counterpart" ) ;
543452
544- drop( tunnel_to_ws_abort_tx) ;
453+ let _ = ping_abort_tx2. send( ( ) ) ;
454+ let _ = tunnel_to_ws_abort_tx. send( ( ) ) ;
545455 } else {
546456 tracing:: debug!( ?res, "ws to tunnel task completed" ) ;
547457 }
548458
549459 res
550- }
460+ } ,
461+ async {
462+ let res = ping. await ?;
463+
464+ // Abort others if not aborted
465+ if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
466+ tracing:: debug!( ?res, "ping task completed, aborting others" ) ;
467+
468+ let _ = ws_to_tunnel_abort_tx2. send( ( ) ) ;
469+ let _ = tunnel_to_ws_abort_tx2. send( ( ) ) ;
470+ } else {
471+ tracing:: debug!( ?res, "ping task completed" ) ;
472+ }
473+
474+ res
475+ } ,
551476 ) ;
552477
553- // Determine single result from both tasks
554- let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res) {
478+ // Determine single result from all tasks
479+ let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res ) {
555480 // Prefer error
556- ( Err ( err) , _) => Err ( err) ,
557- ( _, Err ( err) ) => Err ( err) ,
481+ ( Err ( err) , _, _) => Err ( err) ,
482+ ( _, Err ( err) , _) => Err ( err) ,
483+ ( _, _, Err ( err) ) => Err ( err) ,
558484 // Prefer non aborted result (tunnel to ws or ws to tunnel) if all succeed; the ping result is ignored here
559- ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) ) => Ok ( res) ,
560- ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) ) => Ok ( res) ,
561- // Prefer tunnel to ws if both succeed (unlikely case)
562- ( res, _) => res,
485+ ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) , _ ) => Ok ( res) ,
486+ ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) , _ ) => Ok ( res) ,
487+ // Unlikely case
488+ ( res, _, _ ) => res,
563489 } ;
564490
565491 // Send close frame to runner if not hibernating
0 commit comments