Skip to content

Commit 11b75bb

Browse files
committed
fix(tunnel): implement ping pong
1 parent 759df3e commit 11b75bb

File tree

14 files changed

+462
-236
lines changed

14 files changed

+462
-236
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/packages/pegboard-gateway/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ http-body-util.workspace = true
1515
# TODO: Doesn't match workspace version
1616
hyper = "1.6"
1717
hyper-tungstenite.workspace = true
18+
lazy_static.workspace = true
1819
pegboard.workspace = true
1920
rand.workspace = true
2021
rivet-error.workspace = true
2122
rivet-guard-core.workspace = true
23+
rivet-metrics.workspace = true
2224
rivet-runner-protocol.workspace = true
2325
rivet-util.workspace = true
2426
scc.workspace = true
25-
serde.workspace = true
2627
serde_json.workspace = true
28+
serde.workspace = true
2729
thiserror.workspace = true
2830
tokio-tungstenite.workspace = true
2931
tokio.workspace = true

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 65 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@ use rivet_error::*;
1010
use rivet_guard_core::{
1111
WebSocketHandle,
1212
custom_serve::{CustomServeTrait, HibernationResult},
13-
errors::{
14-
ServiceUnavailable, WebSocketServiceHibernate, WebSocketServiceTimeout,
15-
WebSocketServiceUnavailable,
16-
},
13+
errors::{ServiceUnavailable, WebSocketServiceUnavailable},
1714
proxy_service::{ResponseBody, is_ws_hibernate},
1815
request_context::RequestContext,
1916
websocket_handle::WebSocketReceiver,
@@ -32,10 +29,15 @@ use tokio_tungstenite::tungstenite::{
3229

3330
use crate::shared_state::{InFlightRequestHandle, SharedState};
3431

32+
mod metrics;
33+
mod ping_task;
3534
pub mod shared_state;
35+
mod tunnel_to_ws_task;
36+
mod ws_to_tunnel_task;
3637

3738
const WEBSOCKET_OPEN_TIMEOUT: Duration = Duration::from_secs(15);
3839
const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(5);
40+
const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3);
3941

4042
#[derive(RivetError, Serialize, Deserialize)]
4143
#[error(
@@ -390,147 +392,47 @@ impl CustomServeTrait for PegboardGateway {
390392

391393
let ws_rx = client_ws.recv();
392394

393-
let (tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch::channel(());
394-
let (ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch::channel(());
395-
396-
// Spawn task to forward messages from tunnel to ws
397-
let shared_state = self.shared_state.clone();
398-
let tunnel_to_ws = tokio::spawn(
399-
async move {
400-
loop {
401-
tokio::select! {
402-
res = msg_rx.recv() => {
403-
if let Some(msg) = res {
404-
match msg {
405-
protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => {
406-
let msg = if ws_msg.binary {
407-
Message::Binary(ws_msg.data.into())
408-
} else {
409-
Message::Text(
410-
String::from_utf8_lossy(&ws_msg.data).into_owned().into(),
411-
)
412-
};
413-
client_ws.send(msg).await?;
414-
}
415-
protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => {
416-
tracing::debug!(
417-
request_id=?Uuid::from_bytes(request_id),
418-
ack_index=?ack.index,
419-
"received WebSocketMessageAck from runner"
420-
);
421-
shared_state
422-
.ack_pending_websocket_messages(request_id, ack.index)
423-
.await?;
424-
}
425-
protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => {
426-
tracing::debug!(?close, "server closed websocket");
427-
428-
if can_hibernate && close.hibernate {
429-
return Err(WebSocketServiceHibernate.build());
430-
} else {
431-
// Successful closure
432-
return Ok(LifecycleResult::ServerClose(close));
433-
}
434-
}
435-
_ => {}
436-
}
437-
} else {
438-
tracing::debug!("tunnel sub closed");
439-
return Err(WebSocketServiceHibernate.build());
440-
}
441-
}
442-
_ = stopped_sub.next() => {
443-
tracing::debug!("actor stopped during websocket handler loop");
444-
445-
if can_hibernate {
446-
return Err(WebSocketServiceHibernate.build());
447-
} else {
448-
return Err(WebSocketServiceUnavailable.build());
449-
}
450-
}
451-
_ = drop_rx.changed() => {
452-
tracing::warn!("websocket message timeout");
453-
return Err(WebSocketServiceTimeout.build());
454-
}
455-
_ = tunnel_to_ws_abort_rx.changed() => {
456-
tracing::debug!("task aborted");
457-
return Ok(LifecycleResult::Aborted);
458-
}
459-
}
460-
}
461-
}
462-
.instrument(tracing::info_span!("tunnel_to_ws_task")),
463-
);
464-
465-
// Spawn task to forward messages from ws to tunnel
466-
let shared_state_clone = self.shared_state.clone();
467-
let ws_to_tunnel = tokio::spawn(
468-
async move {
469-
let mut ws_rx = ws_rx.lock().await;
470-
471-
loop {
472-
tokio::select! {
473-
res = ws_rx.try_next() => {
474-
if let Some(msg) = res? {
475-
match msg {
476-
Message::Binary(data) => {
477-
let ws_message =
478-
protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
479-
protocol::ToClientWebSocketMessage {
480-
// NOTE: This gets set in shared_state.ts
481-
index: 0,
482-
data: data.into(),
483-
binary: true,
484-
},
485-
);
486-
shared_state_clone
487-
.send_message(request_id, ws_message)
488-
.await?;
489-
}
490-
Message::Text(text) => {
491-
let ws_message =
492-
protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
493-
protocol::ToClientWebSocketMessage {
494-
// NOTE: This gets set in shared_state.ts
495-
index: 0,
496-
data: text.as_bytes().to_vec(),
497-
binary: false,
498-
},
499-
);
500-
shared_state_clone
501-
.send_message(request_id, ws_message)
502-
.await?;
503-
}
504-
Message::Close(close) => {
505-
return Ok(LifecycleResult::ClientClose(close));
506-
}
507-
_ => {}
508-
}
509-
} else {
510-
tracing::debug!("websocket stream closed");
511-
return Ok(LifecycleResult::ClientClose(None));
512-
}
513-
}
514-
_ = ws_to_tunnel_abort_rx.changed() => {
515-
tracing::debug!("task aborted");
516-
return Ok(LifecycleResult::Aborted);
517-
}
518-
};
519-
}
520-
}
521-
.instrument(tracing::info_span!("ws_to_tunnel_task")),
522-
);
395+
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
396+
let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
397+
let (ping_abort_tx, ping_abort_rx) = watch::channel(());
398+
399+
let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
400+
self.shared_state.clone(),
401+
client_ws,
402+
request_id,
403+
stopped_sub,
404+
msg_rx,
405+
drop_rx,
406+
can_hibernate,
407+
tunnel_to_ws_abort_rx,
408+
));
409+
let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task(
410+
self.shared_state.clone(),
411+
request_id,
412+
ws_rx,
413+
ws_to_tunnel_abort_rx,
414+
));
415+
let ping = tokio::spawn(ping_task::task(
416+
self.shared_state.clone(),
417+
request_id,
418+
ping_abort_rx,
419+
));
420+
421+
let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone();
422+
let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
423+
let ping_abort_tx2 = ping_abort_tx.clone();
523424

524425
// Wait for both tasks to complete
525-
let (tunnel_to_ws_res, ws_to_tunnel_res) = tokio::join!(
426+
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
526427
async {
527428
let res = tunnel_to_ws.await?;
528429

529430
// Abort other if not aborted
530431
if !matches!(res, Ok(LifecycleResult::Aborted)) {
531432
tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");
532433

533-
drop(ws_to_tunnel_abort_tx);
434+
let _ = ping_abort_tx.send(());
435+
let _ = ws_to_tunnel_abort_tx.send(());
534436
} else {
535437
tracing::debug!(?res, "tunnel to ws task completed");
536438
}
@@ -544,25 +446,42 @@ impl CustomServeTrait for PegboardGateway {
544446
if !matches!(res, Ok(LifecycleResult::Aborted)) {
545447
tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");
546448

547-
drop(tunnel_to_ws_abort_tx);
449+
let _ = ping_abort_tx2.send(());
450+
let _ = tunnel_to_ws_abort_tx.send(());
548451
} else {
549452
tracing::debug!(?res, "ws to tunnel task completed");
550453
}
551454

552455
res
553-
}
456+
},
457+
async {
458+
let res = ping.await?;
459+
460+
// Abort others if not aborted
461+
if !matches!(res, Ok(LifecycleResult::Aborted)) {
462+
tracing::debug!(?res, "ping task completed, aborting others");
463+
464+
let _ = ws_to_tunnel_abort_tx2.send(());
465+
let _ = tunnel_to_ws_abort_tx2.send(());
466+
} else {
467+
tracing::debug!(?res, "ping task completed");
468+
}
469+
470+
res
471+
},
554472
);
555473

556-
// Determine single result from both tasks
557-
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res) {
474+
// Determine single result from all tasks
475+
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
558476
// Prefer error
559-
(Err(err), _) => Err(err),
560-
(_, Err(err)) => Err(err),
477+
(Err(err), _, _) => Err(err),
478+
(_, Err(err), _) => Err(err),
479+
(_, _, Err(err)) => Err(err),
561480
// Prefer non aborted result if all succeed
562-
(Ok(res), Ok(LifecycleResult::Aborted)) => Ok(res),
563-
(Ok(LifecycleResult::Aborted), Ok(res)) => Ok(res),
564-
// Prefer tunnel to ws if both succeed (unlikely case)
565-
(res, _) => res,
481+
(Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
482+
(Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
483+
// Unlikely case
484+
(res, _, _) => res,
566485
};
567486

568487
// Send close frame to runner if not hibernating
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use rivet_metrics::{
2+
BUCKETS,
3+
otel::{global::*, metrics::*},
4+
};
5+
6+
lazy_static::lazy_static! {
	// Shared OTEL meter for all pegboard-gateway tunnel metrics.
	static ref METER: Meter = meter("rivet-gateway");

	/// Has no expected attributes
	// Records round-trip time of gateway->pegboard tunnel pings, bucketed
	// with the project-wide BUCKETS boundaries.
	// NOTE(review): unit is not declared on the instrument — presumably
	// seconds; confirm against the value recorded by the ping task.
	pub static ref TUNNEL_PING_DURATION: Histogram<f64> = METER.f64_histogram("rivet_gateway_tunnel_ping_duration")
		.with_description("RTT of messages from gateway to pegboard.")
		.with_boundaries(BUCKETS.to_vec())
		.build();
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use anyhow::Result;
2+
use rivet_runner_protocol as protocol;
3+
use tokio::sync::watch;
4+
5+
use super::{LifecycleResult, UPDATE_PING_INTERVAL};
6+
use crate::shared_state::SharedState;
7+
8+
pub async fn task(
9+
shared_state: SharedState,
10+
request_id: protocol::RequestId,
11+
mut ping_abort_rx: watch::Receiver<()>,
12+
) -> Result<LifecycleResult> {
13+
loop {
14+
tokio::select! {
15+
_ = tokio::time::sleep(UPDATE_PING_INTERVAL) => {}
16+
_ = ping_abort_rx.changed() => {
17+
return Ok(LifecycleResult::Aborted);
18+
}
19+
}
20+
21+
shared_state.send_and_check_ping(request_id).await?;
22+
}
23+
}

0 commit comments

Comments
 (0)