@@ -6,30 +6,27 @@ use gas::prelude::*;
66use http_body_util:: { BodyExt , Full } ;
77use hyper:: { Request , Response , StatusCode } ;
88use pegboard:: tunnel:: id:: { self as tunnel_id, RequestId } ;
9- use rand:: Rng ;
109use rivet_error:: * ;
1110use rivet_guard_core:: {
11+ WebSocketHandle ,
1212 custom_serve:: { CustomServeTrait , HibernationResult } ,
1313 errors:: { ServiceUnavailable , WebSocketServiceUnavailable } ,
14- proxy_service:: { is_ws_hibernate , ResponseBody } ,
14+ proxy_service:: { ResponseBody , is_ws_hibernate } ,
1515 request_context:: RequestContext ,
1616 websocket_handle:: WebSocketReceiver ,
17- WebSocketHandle ,
1817} ;
1918use rivet_runner_protocol as protocol;
2019use rivet_util:: serde:: HashableMap ;
2120use std:: { sync:: Arc , time:: Duration } ;
22- use tokio:: {
23- sync:: { watch, Mutex } ,
24- task:: JoinHandle ,
25- } ;
21+ use tokio:: sync:: { Mutex , watch} ;
2622use tokio_tungstenite:: tungstenite:: {
27- protocol:: frame:: { coding:: CloseCode , CloseFrame } ,
2823 Message ,
24+ protocol:: frame:: { CloseFrame , coding:: CloseCode } ,
2925} ;
3026
3127use crate :: shared_state:: { InFlightRequestHandle , SharedState } ;
3228
29+ mod keepalive_task;
3330mod metrics;
3431mod ping_task;
3532pub mod shared_state;
@@ -396,6 +393,7 @@ impl CustomServeTrait for PegboardGateway {
396393 let ( tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
397394 let ( ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch:: channel ( ( ) ) ;
398395 let ( ping_abort_tx, ping_abort_rx) = watch:: channel ( ( ) ) ;
396+ let ( keepalive_abort_tx, keepalive_abort_rx) = watch:: channel ( ( ) ) ;
399397
400398 let tunnel_to_ws = tokio:: spawn ( tunnel_to_ws_task:: task (
401399 self . shared_state . clone ( ) ,
@@ -423,8 +421,14 @@ impl CustomServeTrait for PegboardGateway {
423421 let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx. clone ( ) ;
424422 let ping_abort_tx2 = ping_abort_tx. clone ( ) ;
425423
426- // Wait for both tasks to complete
427- let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio:: join!(
424+ // Clone variables needed for keepalive task
425+ let ctx_clone = self . ctx . clone ( ) ;
426+ let actor_id_clone = self . actor_id ;
427+ let gateway_id_clone = self . shared_state . gateway_id ( ) ;
428+ let request_id_clone = request_id;
429+
430+ // Wait for all tasks to complete
431+ let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res) = tokio:: join!(
428432 async {
429433 let res = tunnel_to_ws. await ?;
430434
@@ -434,6 +438,7 @@ impl CustomServeTrait for PegboardGateway {
434438
435439 let _ = ping_abort_tx. send( ( ) ) ;
436440 let _ = ws_to_tunnel_abort_tx. send( ( ) ) ;
441+ let _ = keepalive_abort_tx. send( ( ) ) ;
437442 } else {
438443 tracing:: debug!( ?res, "tunnel to ws task completed" ) ;
439444 }
@@ -449,6 +454,7 @@ impl CustomServeTrait for PegboardGateway {
449454
450455 let _ = ping_abort_tx2. send( ( ) ) ;
451456 let _ = tunnel_to_ws_abort_tx. send( ( ) ) ;
457+ let _ = keepalive_abort_tx. send( ( ) ) ;
452458 } else {
453459 tracing:: debug!( ?res, "ws to tunnel task completed" ) ;
454460 }
@@ -464,25 +470,56 @@ impl CustomServeTrait for PegboardGateway {
464470
465471 let _ = ws_to_tunnel_abort_tx2. send( ( ) ) ;
466472 let _ = tunnel_to_ws_abort_tx2. send( ( ) ) ;
473+ let _ = keepalive_abort_tx. send( ( ) ) ;
467474 } else {
468475 tracing:: debug!( ?res, "ping task completed" ) ;
469476 }
470477
478+ res
479+ } ,
480+ async {
481+ if !can_hibernate {
482+ return Ok ( LifecycleResult :: Aborted ) ;
483+ }
484+
485+ let keepalive = tokio:: spawn( keepalive_task:: task(
486+ ctx_clone,
487+ actor_id_clone,
488+ gateway_id_clone,
489+ request_id_clone,
490+ keepalive_abort_rx,
491+ ) ) ;
492+
493+ let res = keepalive. await ?;
494+
495+ // Abort others if not aborted
496+ if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
497+ tracing:: debug!( ?res, "keepalive task completed, aborting others" ) ;
498+
499+ let _ = ws_to_tunnel_abort_tx2. send( ( ) ) ;
500+ let _ = tunnel_to_ws_abort_tx2. send( ( ) ) ;
501+ let _ = ping_abort_tx2. send( ( ) ) ;
502+ } else {
503+ tracing:: debug!( ?res, "keepalive task completed" ) ;
504+ }
505+
471506 res
472507 } ,
473508 ) ;
474509
475510 // Determine single result from all tasks
476- let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
511+ let mut lifecycle_res = match ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res)
512+ {
477513 // Prefer error
478- ( Err ( err) , _, _) => Err ( err) ,
479- ( _, Err ( err) , _) => Err ( err) ,
480- ( _, _, Err ( err) ) => Err ( err) ,
481- // Prefer non aborted result if both succeed
482- ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) , _) => Ok ( res) ,
483- ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) , _) => Ok ( res) ,
514+ ( Err ( err) , _, _, _) => Err ( err) ,
515+ ( _, Err ( err) , _, _) => Err ( err) ,
516+ ( _, _, Err ( err) , _) => Err ( err) ,
517+ ( _, _, _, Err ( err) ) => Err ( err) ,
518+ // Prefer non aborted result if all succeed
519+ ( Ok ( res) , Ok ( LifecycleResult :: Aborted ) , _, _) => Ok ( res) ,
520+ ( Ok ( LifecycleResult :: Aborted ) , Ok ( res) , _, _) => Ok ( res) ,
484521 // Unlikely case
485- ( res, _, _) => res,
522+ ( res, _, _, _ ) => res,
486523 } ;
487524
488525 // Send close frame to runner if not hibernating
@@ -564,43 +601,19 @@ impl CustomServeTrait for PegboardGateway {
564601 }
565602
566603 // Start keepalive task
567- let ctx = self . ctx . clone ( ) ;
568- let actor_id = self . actor_id ;
569- let gateway_id = self . shared_state . gateway_id ( ) ;
570- let request_id = unique_request_id;
571- let keepalive_handle: JoinHandle < Result < ( ) > > = tokio:: spawn ( async move {
572- let mut ping_interval = tokio:: time:: interval ( Duration :: from_millis (
573- ( ctx. config ( )
574- . pegboard ( )
575- . hibernating_request_eligible_threshold ( )
576- / 2 )
577- . try_into ( ) ?,
578- ) ) ;
579- ping_interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
580-
581- // Discard the first tick since it fires immediately and we've already called this
582- // above
583- ping_interval. tick ( ) . await ;
584-
585- loop {
586- ping_interval. tick ( ) . await ;
587-
588- // Jitter sleep to prevent stampeding herds
589- let jitter = { rand:: thread_rng ( ) . gen_range ( 0 ..128 ) } ;
590- tokio:: time:: sleep ( Duration :: from_millis ( jitter) ) . await ;
591-
592- ctx. op ( pegboard:: ops:: actor:: hibernating_request:: upsert:: Input {
593- actor_id,
594- gateway_id,
595- request_id,
596- } )
597- . await ?;
598- }
599- } ) ;
604+ let ( keepalive_abort_tx, keepalive_abort_rx) = watch:: channel ( ( ) ) ;
605+ let keepalive_handle = tokio:: spawn ( keepalive_task:: task (
606+ self . ctx . clone ( ) ,
607+ self . actor_id ,
608+ self . shared_state . gateway_id ( ) ,
609+ unique_request_id,
610+ keepalive_abort_rx,
611+ ) ) ;
600612
601613 let res = self . handle_websocket_hibernation_inner ( client_ws) . await ;
602614
603- keepalive_handle. abort ( ) ;
615+ let _ = keepalive_abort_tx. send ( ( ) ) ;
616+ let _ = keepalive_handle. await ;
604617
605618 match & res {
606619 Ok ( HibernationResult :: Continue ) => { }
0 commit comments