@@ -10,10 +10,7 @@ use rivet_error::*;
1010use rivet_guard_core:: {
1111 WebSocketHandle ,
1212 custom_serve:: { CustomServeTrait , HibernationResult } ,
13- errors:: {
14- ServiceUnavailable , WebSocketServiceHibernate , WebSocketServiceTimeout ,
15- WebSocketServiceUnavailable ,
16- } ,
13+ errors:: { ServiceUnavailable , WebSocketServiceUnavailable } ,
1714 proxy_service:: { ResponseBody , is_ws_hibernate} ,
1815 request_context:: RequestContext ,
1916 websocket_handle:: WebSocketReceiver ,
@@ -32,10 +29,15 @@ use tokio_tungstenite::tungstenite::{
3229
3330use crate :: shared_state:: { InFlightRequestHandle , SharedState } ;
3431
32+ mod metrics;
33+ mod ping_task;
3534pub mod shared_state;
35+ mod tunnel_to_ws_task;
36+ mod ws_to_tunnel_task;
3637
3738const WEBSOCKET_OPEN_TIMEOUT : Duration = Duration :: from_secs ( 15 ) ;
3839const TUNNEL_ACK_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
40+ const UPDATE_PING_INTERVAL : Duration = Duration :: from_secs ( 3 ) ;
3941
4042#[ derive( RivetError , Serialize , Deserialize ) ]
4143#[ error(
@@ -390,147 +392,47 @@ impl CustomServeTrait for PegboardGateway {
390392
391393 let ws_rx = client_ws. recv ( ) ;
392394
393- let ( tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
394- let ( ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
395-
396- // Spawn task to forward messages from tunnel to ws
397- let shared_state = self . shared_state . clone ( ) ;
398- let tunnel_to_ws = tokio:: spawn (
399- async move {
400- loop {
401- tokio:: select! {
402- res = msg_rx. recv( ) => {
403- if let Some ( msg) = res {
404- match msg {
405- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketMessage ( ws_msg) => {
406- let msg = if ws_msg. binary {
407- Message :: Binary ( ws_msg. data. into( ) )
408- } else {
409- Message :: Text (
410- String :: from_utf8_lossy( & ws_msg. data) . into_owned( ) . into( ) ,
411- )
412- } ;
413- client_ws. send( msg) . await ?;
414- }
415- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketMessageAck ( ack) => {
416- tracing:: debug!(
417- request_id=?Uuid :: from_bytes( request_id) ,
418- ack_index=?ack. index,
419- "received WebSocketMessageAck from runner"
420- ) ;
421- shared_state
422- . ack_pending_websocket_messages( request_id, ack. index)
423- . await ?;
424- }
425- protocol:: ToServerTunnelMessageKind :: ToServerWebSocketClose ( close) => {
426- tracing:: debug!( ?close, "server closed websocket" ) ;
427-
428- if can_hibernate && close. hibernate {
429- return Err ( WebSocketServiceHibernate . build( ) ) ;
430- } else {
431- // Successful closure
432- return Ok ( LifecycleResult :: ServerClose ( close) ) ;
433- }
434- }
435- _ => { }
436- }
437- } else {
438- tracing:: debug!( "tunnel sub closed" ) ;
439- return Err ( WebSocketServiceHibernate . build( ) ) ;
440- }
441- }
442- _ = stopped_sub. next( ) => {
443- tracing:: debug!( "actor stopped during websocket handler loop" ) ;
444-
445- if can_hibernate {
446- return Err ( WebSocketServiceHibernate . build( ) ) ;
447- } else {
448- return Err ( WebSocketServiceUnavailable . build( ) ) ;
449- }
450- }
451- _ = drop_rx. changed( ) => {
452- tracing:: warn!( "websocket message timeout" ) ;
453- return Err ( WebSocketServiceTimeout . build( ) ) ;
454- }
455- _ = tunnel_to_ws_abort_rx. changed( ) => {
456- tracing:: debug!( "task aborted" ) ;
457- return Ok ( LifecycleResult :: Aborted ) ;
458- }
459- }
460- }
461- }
462- . instrument ( tracing:: info_span!( "tunnel_to_ws_task" ) ) ,
463- ) ;
464-
465- // Spawn task to forward messages from ws to tunnel
466- let shared_state_clone = self . shared_state . clone ( ) ;
467- let ws_to_tunnel = tokio:: spawn (
468- async move {
469- let mut ws_rx = ws_rx. lock ( ) . await ;
470-
471- loop {
472- tokio:: select! {
473- res = ws_rx. try_next( ) => {
474- if let Some ( msg) = res? {
475- match msg {
476- Message :: Binary ( data) => {
477- let ws_message =
478- protocol:: ToClientTunnelMessageKind :: ToClientWebSocketMessage (
479- protocol:: ToClientWebSocketMessage {
480- // NOTE: This gets set in shared_state.ts
481- index: 0 ,
482- data: data. into( ) ,
483- binary: true ,
484- } ,
485- ) ;
486- shared_state_clone
487- . send_message( request_id, ws_message)
488- . await ?;
489- }
490- Message :: Text ( text) => {
491- let ws_message =
492- protocol:: ToClientTunnelMessageKind :: ToClientWebSocketMessage (
493- protocol:: ToClientWebSocketMessage {
494- // NOTE: This gets set in shared_state.ts
495- index: 0 ,
496- data: text. as_bytes( ) . to_vec( ) ,
497- binary: false ,
498- } ,
499- ) ;
500- shared_state_clone
501- . send_message( request_id, ws_message)
502- . await ?;
503- }
504- Message :: Close ( close) => {
505- return Ok ( LifecycleResult :: ClientClose ( close) ) ;
506- }
507- _ => { }
508- }
509- } else {
510- tracing:: debug!( "websocket stream closed" ) ;
511- return Ok ( LifecycleResult :: ClientClose ( None ) ) ;
512- }
513- }
514- _ = ws_to_tunnel_abort_rx. changed( ) => {
515- tracing:: debug!( "task aborted" ) ;
516- return Ok ( LifecycleResult :: Aborted ) ;
517- }
518- } ;
519- }
520- }
521- . instrument ( tracing:: info_span!( "ws_to_tunnel_task" ) ) ,
522- ) ;
395+ let ( tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
396+ let ( ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
397+ let ( ping_abort_tx, ping_abort_rx) = watch:: channel ( ( ) ) ;
398+
399+ let tunnel_to_ws = tokio:: spawn ( tunnel_to_ws_task:: task (
400+ self . shared_state . clone ( ) ,
401+ client_ws,
402+ request_id,
403+ stopped_sub,
404+ msg_rx,
405+ drop_rx,
406+ can_hibernate,
407+ tunnel_to_ws_abort_rx,
408+ ) ) ;
409+ let ws_to_tunnel = tokio:: spawn ( ws_to_tunnel_task:: task (
410+ self . shared_state . clone ( ) ,
411+ request_id,
412+ ws_rx,
413+ ws_to_tunnel_abort_rx,
414+ ) ) ;
415+ let ping = tokio:: spawn ( ping_task:: task (
416+ self . shared_state . clone ( ) ,
417+ request_id,
418+ ping_abort_rx,
419+ ) ) ;
420+
421+ let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx. clone ( ) ;
422+ let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx. clone ( ) ;
423+ let ping_abort_tx2 = ping_abort_tx. clone ( ) ;
523424
524425 // Wait for both tasks to complete
525- let ( tunnel_to_ws_res, ws_to_tunnel_res) = tokio:: join!(
426+ let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res ) = tokio:: join!(
526427 async {
527428 let res = tunnel_to_ws. await ?;
528429
529430 // Abort other if not aborted
530431 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
531432 tracing:: debug!( ?res, "tunnel to ws task completed, aborting counterpart" ) ;
532433
533- drop( ws_to_tunnel_abort_tx) ;
434+ let _ = ping_abort_tx. send( ( ) ) ;
435+ let _ = ws_to_tunnel_abort_tx. send( ( ) ) ;
534436 } else {
535437 tracing:: debug!( ?res, "tunnel to ws task completed" ) ;
536438 }
@@ -544,25 +446,42 @@ impl CustomServeTrait for PegboardGateway {
544446 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
545447 tracing:: debug!( ?res, "ws to tunnel task completed, aborting counterpart" ) ;
546448
547- drop( tunnel_to_ws_abort_tx) ;
449+ let _ = ping_abort_tx2. send( ( ) ) ;
450+ let _ = tunnel_to_ws_abort_tx. send( ( ) ) ;
548451 } else {
549452 tracing:: debug!( ?res, "ws to tunnel task completed" ) ;
550453 }
551454
552455 res
553- }
456+ } ,
457+ async {
458+ let res = ping. await ?;
459+
460+ // Abort others if not aborted
461+ if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
462+ tracing:: debug!( ?res, "ping task completed, aborting others" ) ;
463+
464+ let _ = ws_to_tunnel_abort_tx2. send( ( ) ) ;
465+ let _ = tunnel_to_ws_abort_tx2. send( ( ) ) ;
466+ } else {
467+ tracing:: debug!( ?res, "ping task completed" ) ;
468+ }
469+
470+ res
471+ } ,
554472 ) ;
555473
556- // Determine single result from both tasks
557- let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res) {
474+ // Determine single result from all tasks
475+ let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res ) {
558476 // Prefer error
559- ( Err ( err) , _) => Err ( err) ,
560- ( _, Err ( err) ) => Err ( err) ,
477+ ( Err ( err) , _, _) => Err ( err) ,
478+ ( _, Err ( err) , _) => Err ( err) ,
479+ ( _, _, Err ( err) ) => Err ( err) ,
561480 // Prefer non aborted result if both succeed
562- ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) ) => Ok ( res) ,
563- ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) ) => Ok ( res) ,
564- // Prefer tunnel to ws if both succeed (unlikely case)
565- ( res, _) => res,
481+ ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) , _ ) => Ok ( res) ,
482+ ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) , _ ) => Ok ( res) ,
483+ // Unlikely case
484+ ( res, _, _ ) => res,
566485 } ;
567486
568487 // Send close frame to runner if not hibernating
0 commit comments