Commit 5552b74

fix(pegboard-gateway): ping hibernating requests both during open hws connections and during hibernation
1 parent 4f192f3

2 files changed: +118 −52 lines

engine/packages/pegboard-gateway/src/keepalive_task.rs (new file)

Lines changed: 53 additions & 0 deletions

```rust
use anyhow::Result;
use gas::prelude::*;
use pegboard::tunnel::id::{GatewayId, RequestId};
use rand::Rng;
use std::time::Duration;
use tokio::sync::watch;

use super::LifecycleResult;

/// Periodically writes keepalive pings to UDB. This is used to restore hibernating request IDs
/// on the next actor start.
///
/// Only run for hibernating requests.
pub async fn task(
    ctx: StandaloneCtx,
    actor_id: Id,
    gateway_id: GatewayId,
    request_id: RequestId,
    mut keepalive_abort_rx: watch::Receiver<()>,
) -> Result<LifecycleResult> {
    let mut ping_interval = tokio::time::interval(Duration::from_millis(
        (ctx.config()
            .pegboard()
            .hibernating_request_eligible_threshold()
            / 2)
        .try_into()?,
    ));
    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    // Discard the first tick since it fires immediately and we've already called this
    // above
    ping_interval.tick().await;

    loop {
        tokio::select! {
            _ = ping_interval.tick() => {}
            _ = keepalive_abort_rx.changed() => {
                return Ok(LifecycleResult::Aborted);
            }
        }

        // Jitter sleep to prevent stampeding herds
        let jitter = { rand::thread_rng().gen_range(0..128) };
        tokio::time::sleep(Duration::from_millis(jitter)).await;

        ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
            actor_id,
            gateway_id,
            request_id,
        })
        .await?;
    }
}
```
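For readers skimming the commit, the heart of the new module is an interval-or-abort loop built from `tokio::time::interval`, `tokio::select!`, and a `watch` channel. Below is a minimal, self-contained sketch of that same pattern using only tokio; `keepalive_tick`, the 500 ms interval, and the timings are illustrative stand-ins rather than pegboard APIs (the real task instead calls the `hibernating_request::upsert` operation on each tick, after a small random jitter).

```rust
use std::time::Duration;
use tokio::sync::watch;

// Illustrative stand-in for the UDB upsert performed on each tick.
async fn keepalive_tick(n: u32) {
    println!("keepalive tick {n}");
}

async fn keepalive_loop(mut abort_rx: watch::Receiver<()>) {
    let mut interval = tokio::time::interval(Duration::from_millis(500));
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    // Discard the first tick, which fires immediately (same as the task above).
    interval.tick().await;

    let mut n = 0;
    loop {
        tokio::select! {
            _ = interval.tick() => {}
            // `changed()` resolves when the sender publishes a value (or is dropped),
            // so the loop exits promptly instead of waiting out the interval.
            _ = abort_rx.changed() => return,
        }
        n += 1;
        keepalive_tick(n).await;
    }
}

#[tokio::main]
async fn main() {
    let (abort_tx, abort_rx) = watch::channel(());
    let handle = tokio::spawn(keepalive_loop(abort_rx));

    tokio::time::sleep(Duration::from_millis(1600)).await;
    let _ = abort_tx.send(()); // request shutdown
    let _ = handle.await;      // wait for the loop to observe it and return
}
```

Returning `LifecycleResult::Aborted` from the abort branch, as the real task does, lets the caller distinguish a requested shutdown from a completed or failed keepalive.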

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 65 additions & 52 deletions
```diff
@@ -6,30 +6,27 @@ use gas::prelude::*;
 use http_body_util::{BodyExt, Full};
 use hyper::{Request, Response, StatusCode};
 use pegboard::tunnel::id::{self as tunnel_id, RequestId};
-use rand::Rng;
 use rivet_error::*;
 use rivet_guard_core::{
+    WebSocketHandle,
     custom_serve::{CustomServeTrait, HibernationResult},
     errors::{ServiceUnavailable, WebSocketServiceUnavailable},
-    proxy_service::{is_ws_hibernate, ResponseBody},
+    proxy_service::{ResponseBody, is_ws_hibernate},
     request_context::RequestContext,
     websocket_handle::WebSocketReceiver,
-    WebSocketHandle,
 };
 use rivet_runner_protocol as protocol;
 use rivet_util::serde::HashableMap;
 use std::{sync::Arc, time::Duration};
-use tokio::{
-    sync::{watch, Mutex},
-    task::JoinHandle,
-};
+use tokio::sync::{Mutex, watch};
 use tokio_tungstenite::tungstenite::{
-    protocol::frame::{coding::CloseCode, CloseFrame},
     Message,
+    protocol::frame::{CloseFrame, coding::CloseCode},
 };

 use crate::shared_state::{InFlightRequestHandle, SharedState};

+mod keepalive_task;
 mod metrics;
 mod ping_task;
 pub mod shared_state;
@@ -396,6 +393,7 @@ impl CustomServeTrait for PegboardGateway {
         let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
         let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
         let (ping_abort_tx, ping_abort_rx) = watch::channel(());
+        let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());

         let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
             self.shared_state.clone(),
@@ -423,8 +421,14 @@ impl CustomServeTrait for PegboardGateway {
         let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
         let ping_abort_tx2 = ping_abort_tx.clone();

-        // Wait for both tasks to complete
-        let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
+        // Clone variables needed for keepalive task
+        let ctx_clone = self.ctx.clone();
+        let actor_id_clone = self.actor_id;
+        let gateway_id_clone = self.shared_state.gateway_id();
+        let request_id_clone = request_id;
+
+        // Wait for all tasks to complete
+        let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res) = tokio::join!(
             async {
                 let res = tunnel_to_ws.await?;

@@ -434,6 +438,7 @@ impl CustomServeTrait for PegboardGateway {

                     let _ = ping_abort_tx.send(());
                     let _ = ws_to_tunnel_abort_tx.send(());
+                    let _ = keepalive_abort_tx.send(());
                 } else {
                     tracing::debug!(?res, "tunnel to ws task completed");
                 }
@@ -449,6 +454,7 @@ impl CustomServeTrait for PegboardGateway {

                     let _ = ping_abort_tx2.send(());
                     let _ = tunnel_to_ws_abort_tx.send(());
+                    let _ = keepalive_abort_tx.send(());
                 } else {
                     tracing::debug!(?res, "ws to tunnel task completed");
                 }
@@ -464,25 +470,56 @@ impl CustomServeTrait for PegboardGateway {

                     let _ = ws_to_tunnel_abort_tx2.send(());
                     let _ = tunnel_to_ws_abort_tx2.send(());
+                    let _ = keepalive_abort_tx.send(());
                 } else {
                     tracing::debug!(?res, "ping task completed");
                 }

+                res
+            },
+            async {
+                if !can_hibernate {
+                    return Ok(LifecycleResult::Aborted);
+                }
+
+                let keepalive = tokio::spawn(keepalive_task::task(
+                    ctx_clone,
+                    actor_id_clone,
+                    gateway_id_clone,
+                    request_id_clone,
+                    keepalive_abort_rx,
+                ));
+
+                let res = keepalive.await?;
+
+                // Abort others if not aborted
+                if !matches!(res, Ok(LifecycleResult::Aborted)) {
+                    tracing::debug!(?res, "keepalive task completed, aborting others");
+
+                    let _ = ws_to_tunnel_abort_tx2.send(());
+                    let _ = tunnel_to_ws_abort_tx2.send(());
+                    let _ = ping_abort_tx2.send(());
+                } else {
+                    tracing::debug!(?res, "keepalive task completed");
+                }
+
                 res
             },
         );

         // Determine single result from all tasks
-        let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
+        let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res)
+        {
             // Prefer error
-            (Err(err), _, _) => Err(err),
-            (_, Err(err), _) => Err(err),
-            (_, _, Err(err)) => Err(err),
-            // Prefer non aborted result if both succeed
-            (Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
-            (Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
+            (Err(err), _, _, _) => Err(err),
+            (_, Err(err), _, _) => Err(err),
+            (_, _, Err(err), _) => Err(err),
+            (_, _, _, Err(err)) => Err(err),
+            // Prefer non aborted result if all succeed
+            (Ok(res), Ok(LifecycleResult::Aborted), _, _) => Ok(res),
+            (Ok(LifecycleResult::Aborted), Ok(res), _, _) => Ok(res),
             // Unlikely case
-            (res, _, _) => res,
+            (res, _, _, _) => res,
         };

         // Send close frame to runner if not hibernating
@@ -564,43 +601,19 @@ impl CustomServeTrait for PegboardGateway {
         }

         // Start keepalive task
-        let ctx = self.ctx.clone();
-        let actor_id = self.actor_id;
-        let gateway_id = self.shared_state.gateway_id();
-        let request_id = unique_request_id;
-        let keepalive_handle: JoinHandle<Result<()>> = tokio::spawn(async move {
-            let mut ping_interval = tokio::time::interval(Duration::from_millis(
-                (ctx.config()
-                    .pegboard()
-                    .hibernating_request_eligible_threshold()
-                    / 2)
-                .try_into()?,
-            ));
-            ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
-
-            // Discard the first tick since it fires immediately and we've already called this
-            // above
-            ping_interval.tick().await;
-
-            loop {
-                ping_interval.tick().await;
-
-                // Jitter sleep to prevent stampeding herds
-                let jitter = { rand::thread_rng().gen_range(0..128) };
-                tokio::time::sleep(Duration::from_millis(jitter)).await;
-
-                ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
-                    actor_id,
-                    gateway_id,
-                    request_id,
-                })
-                .await?;
-            }
-        });
+        let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
+        let keepalive_handle = tokio::spawn(keepalive_task::task(
+            self.ctx.clone(),
+            self.actor_id,
+            self.shared_state.gateway_id(),
+            unique_request_id,
+            keepalive_abort_rx,
+        ));

         let res = self.handle_websocket_hibernation_inner(client_ws).await;

-        keepalive_handle.abort();
+        let _ = keepalive_abort_tx.send(());
+        let _ = keepalive_handle.await;

         match &res {
             Ok(HibernationResult::Continue) => {}
```
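One design note on the hibernation path above: the previous code tore the keepalive down with `keepalive_handle.abort()`, while this commit sends on `keepalive_abort_tx` and then awaits the handle. The sketch below, with a hypothetical `worker` standing in for the keepalive task, illustrates the difference: cooperative shutdown lets the task notice the signal at its `select!` point and return its own result, whereas `abort()` cancels the future at whatever await point it happens to be parked on.

```rust
use std::time::Duration;
use tokio::sync::watch;

// Hypothetical worker used only for illustration.
async fn worker(mut abort_rx: watch::Receiver<()>) -> &'static str {
    loop {
        tokio::select! {
            // Simulated unit of work that should not be cut off midway.
            _ = tokio::time::sleep(Duration::from_millis(100)) => {}
            _ = abort_rx.changed() => return "aborted cleanly",
        }
    }
}

#[tokio::main]
async fn main() {
    let (abort_tx, abort_rx) = watch::channel(());
    let handle = tokio::spawn(worker(abort_rx));

    tokio::time::sleep(Duration::from_millis(250)).await;

    // Cooperative shutdown: signal, then await. The task runs until it
    // observes the signal and returns its own result.
    let _ = abort_tx.send(());
    let res = handle.await;
    assert_eq!(res.unwrap(), "aborted cleanly");

    // By contrast, `handle.abort()` would cancel the future at an arbitrary
    // await point, and `handle.await` would yield a cancelled JoinError
    // rather than a return value.
}
```

Awaiting the handle after signalling also guarantees the keepalive loop has fully stopped before hibernation handling continues.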