
Commit 73e19ae

MasterPtato authored and NathanFlurry committed
fix(tunnel): implement ping pong
1 parent e47241f commit 73e19ae

14 files changed, +454 -204 lines


Cargo.lock

Lines changed: 2 additions & 0 deletions

engine/packages/pegboard-gateway/Cargo.toml

Lines changed: 3 additions & 1 deletion
@@ -15,15 +15,17 @@ http-body-util.workspace = true
 # TODO: Doesn't match workspace version
 hyper = "1.6"
 hyper-tungstenite.workspace = true
+lazy_static.workspace = true
 pegboard.workspace = true
 rand.workspace = true
 rivet-error.workspace = true
 rivet-guard-core.workspace = true
+rivet-metrics.workspace = true
 rivet-runner-protocol.workspace = true
 rivet-util.workspace = true
 scc.workspace = true
-serde.workspace = true
 serde_json.workspace = true
+serde.workspace = true
 thiserror.workspace = true
 tokio-tungstenite.workspace = true
 tokio.workspace = true

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 65 additions & 139 deletions
@@ -11,7 +11,7 @@ use rivet_error::*;
 use rivet_guard_core::{
 	custom_serve::{CustomServeTrait, HibernationResult},
 	errors::{
-		ServiceUnavailable, WebSocketServiceHibernate, WebSocketServiceTimeout,
+		ServiceUnavailable,
 		WebSocketServiceUnavailable,
 	},
 	proxy_service::{is_ws_hibernate, ResponseBody},
@@ -33,10 +33,15 @@ use tokio_tungstenite::tungstenite::{

 use crate::shared_state::{InFlightRequestHandle, SharedState};

+mod metrics;
+mod ping_task;
 pub mod shared_state;
+mod tunnel_to_ws_task;
+mod ws_to_tunnel_task;

 const WEBSOCKET_OPEN_TIMEOUT: Duration = Duration::from_secs(15);
 const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(5);
+const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3);

 #[derive(RivetError, Serialize, Deserialize)]
 #[error(
@@ -391,143 +396,47 @@ impl CustomServeTrait for PegboardGateway {

 		let ws_rx = client_ws.recv();

-		let (tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch::channel(());
-		let (ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch::channel(());
-
-		// Spawn task to forward messages from tunnel to ws
-		let shared_state = self.shared_state.clone();
-		let tunnel_to_ws = tokio::spawn(
-			async move {
-				loop {
-					tokio::select! {
-						res = msg_rx.recv() => {
-							if let Some(msg) = res {
-								match msg {
-									protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => {
-										let msg = if ws_msg.binary {
-											Message::Binary(ws_msg.data.into())
-										} else {
-											Message::Text(
-												String::from_utf8_lossy(&ws_msg.data).into_owned().into(),
-											)
-										};
-										client_ws.send(msg).await?;
-									}
-									protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => {
-										tracing::debug!(
-											request_id=?tunnel_id::request_id_to_string(&request_id),
-											ack_index=?ack.index,
-											"received WebSocketMessageAck from runner"
-										);
-										shared_state
-											.ack_pending_websocket_messages(request_id, ack.index)
-											.await?;
-									}
-									protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => {
-										tracing::debug!(?close, "server closed websocket");
-
-										if can_hibernate && close.hibernate {
-											return Err(WebSocketServiceHibernate.build());
-										} else {
-											// Successful closure
-											return Ok(LifecycleResult::ServerClose(close));
-										}
-									}
-									_ => {}
-								}
-							} else {
-								tracing::debug!("tunnel sub closed");
-								return Err(WebSocketServiceHibernate.build());
-							}
-						}
-						_ = stopped_sub.next() => {
-							tracing::debug!("actor stopped during websocket handler loop");
-
-							if can_hibernate {
-								return Err(WebSocketServiceHibernate.build());
-							} else {
-								return Err(WebSocketServiceUnavailable.build());
-							}
-						}
-						_ = drop_rx.changed() => {
-							tracing::warn!("websocket message timeout");
-							return Err(WebSocketServiceTimeout.build());
-						}
-						_ = tunnel_to_ws_abort_rx.changed() => {
-							tracing::debug!("task aborted");
-							return Ok(LifecycleResult::Aborted);
-						}
-					}
-				}
-			}
-			.instrument(tracing::info_span!("tunnel_to_ws_task")),
-		);
-
-		// Spawn task to forward messages from ws to tunnel
-		let shared_state_clone = self.shared_state.clone();
-		let ws_to_tunnel = tokio::spawn(
-			async move {
-				let mut ws_rx = ws_rx.lock().await;
-
-				loop {
-					tokio::select! {
-						res = ws_rx.try_next() => {
-							if let Some(msg) = res? {
-								match msg {
-									Message::Binary(data) => {
-										let ws_message =
-											protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
-												protocol::ToClientWebSocketMessage {
-													data: data.into(),
-													binary: true,
-												},
-											);
-										shared_state_clone
-											.send_message(request_id, ws_message)
-											.await?;
-									}
-									Message::Text(text) => {
-										let ws_message =
-											protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
-												protocol::ToClientWebSocketMessage {
-													data: text.as_bytes().to_vec(),
-													binary: false,
-												},
-											);
-										shared_state_clone
-											.send_message(request_id, ws_message)
-											.await?;
-									}
-									Message::Close(close) => {
-										return Ok(LifecycleResult::ClientClose(close));
-									}
-									_ => {}
-								}
-							} else {
-								tracing::debug!("websocket stream closed");
-								return Ok(LifecycleResult::ClientClose(None));
-							}
-						}
-						_ = ws_to_tunnel_abort_rx.changed() => {
-							tracing::debug!("task aborted");
-							return Ok(LifecycleResult::Aborted);
-						}
-					};
-				}
-			}
-			.instrument(tracing::info_span!("ws_to_tunnel_task")),
-		);
+		let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
+		let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
+		let (ping_abort_tx, ping_abort_rx) = watch::channel(());
+
+		let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
+			self.shared_state.clone(),
+			client_ws,
+			request_id,
+			stopped_sub,
+			msg_rx,
+			drop_rx,
+			can_hibernate,
+			tunnel_to_ws_abort_rx,
+		));
+		let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task(
+			self.shared_state.clone(),
+			request_id,
+			ws_rx,
+			ws_to_tunnel_abort_rx,
+		));
+		let ping = tokio::spawn(ping_task::task(
+			self.shared_state.clone(),
+			request_id,
+			ping_abort_rx,
+		));
+
+		let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone();
+		let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
+		let ping_abort_tx2 = ping_abort_tx.clone();

 		// Wait for both tasks to complete
-		let (tunnel_to_ws_res, ws_to_tunnel_res) = tokio::join!(
+		let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
 			async {
 				let res = tunnel_to_ws.await?;

 				// Abort other if not aborted
 				if !matches!(res, Ok(LifecycleResult::Aborted)) {
 					tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");

-					drop(ws_to_tunnel_abort_tx);
+					let _ = ping_abort_tx.send(());
+					let _ = ws_to_tunnel_abort_tx.send(());
 				} else {
 					tracing::debug!(?res, "tunnel to ws task completed");
 				}
@@ -541,25 +450,42 @@ impl CustomServeTrait for PegboardGateway {
 				if !matches!(res, Ok(LifecycleResult::Aborted)) {
 					tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");

-					drop(tunnel_to_ws_abort_tx);
+					let _ = ping_abort_tx2.send(());
+					let _ = tunnel_to_ws_abort_tx.send(());
 				} else {
 					tracing::debug!(?res, "ws to tunnel task completed");
 				}

 				res
-			}
+			},
+			async {
+				let res = ping.await?;
+
+				// Abort others if not aborted
+				if !matches!(res, Ok(LifecycleResult::Aborted)) {
+					tracing::debug!(?res, "ping task completed, aborting others");
+
+					let _ = ws_to_tunnel_abort_tx2.send(());
+					let _ = tunnel_to_ws_abort_tx2.send(());
+				} else {
+					tracing::debug!(?res, "ping task completed");
+				}
+
+				res
+			},
 		);

-		// Determine single result from both tasks
-		let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res) {
+		// Determine single result from all tasks
+		let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
 			// Prefer error
-			(Err(err), _) => Err(err),
-			(_, Err(err)) => Err(err),
+			(Err(err), _, _) => Err(err),
+			(_, Err(err), _) => Err(err),
+			(_, _, Err(err)) => Err(err),
 			// Prefer non aborted result if both succeed
-			(Ok(res), Ok(LifecycleResult::Aborted)) => Ok(res),
-			(Ok(LifecycleResult::Aborted), Ok(res)) => Ok(res),
-			// Prefer tunnel to ws if both succeed (unlikely case)
-			(res, _) => res,
+			(Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
+			(Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
+			// Unlikely case
+			(res, _, _) => res,
 		};

 		// Send close frame to runner if not hibernating
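Note on the refactor above: the two inline forwarding loops move into `tunnel_to_ws_task` and `ws_to_tunnel_task`, a third `ping_task` is added, and the three are coordinated through `tokio::sync::watch` abort channels. Whichever task finishes first (other than by being aborted) signals the others, and `tokio::join!` collects all three results. The switch from `drop(abort_tx)` to explicit `send(())` likely follows from the senders now being cloned across join arms: dropping one clone no longer wakes the receivers. The self-contained sketch below illustrates only that coordination pattern; the task bodies, names, and timings are placeholders, not code from this commit.

use tokio::sync::watch;
use tokio::time::{sleep, Duration};

#[derive(Debug)]
enum Outcome {
	Finished,
	Aborted,
}

// Placeholder task: does periodic work until its abort channel fires.
async fn worker(name: &'static str, mut abort_rx: watch::Receiver<()>) -> Outcome {
	loop {
		tokio::select! {
			_ = sleep(Duration::from_millis(100)) => {
				println!("{name}: tick");
			}
			_ = abort_rx.changed() => {
				println!("{name}: aborted");
				return Outcome::Aborted;
			}
		}
	}
}

#[tokio::main]
async fn main() {
	let (a_tx, a_rx) = watch::channel(());
	let (b_tx, b_rx) = watch::channel(());

	let a = tokio::spawn(worker("a", a_rx));
	let b = tokio::spawn(worker("b", b_rx));

	// A third task finishes on its own and signals the others to stop,
	// mirroring how the gateway tasks abort their counterparts.
	let driver = tokio::spawn(async move {
		sleep(Duration::from_millis(350)).await;
		let _ = a_tx.send(());
		let _ = b_tx.send(());
		Outcome::Finished
	});

	// Collect all three results, as the gateway does with tokio::join!.
	let (a_res, b_res, d_res) = tokio::join!(a, b, driver);
	println!("{a_res:?} {b_res:?} {d_res:?}");
}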
engine/packages/pegboard-gateway/src/metrics.rs

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+use rivet_metrics::{
+	BUCKETS,
+	otel::{global::*, metrics::*},
+};
+
+lazy_static::lazy_static! {
+	static ref METER: Meter = meter("rivet-gateway");
+
+	/// Has no expected attributes
+	pub static ref TUNNEL_PING_DURATION: Histogram<f64> = METER.f64_histogram("rivet_gateway_tunnel_ping_duration")
+		.with_description("RTT of messages from gateway to pegboard.")
+		.with_boundaries(BUCKETS.to_vec())
+		.build();
+}
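This new module only declares the histogram; the place where a sample is actually recorded is not among the files rendered in this commit (presumably inside `shared_state`, which the ping task below calls). As a rough, standalone illustration of feeding an RTT sample into an OpenTelemetry f64 histogram like this one, using the plain `opentelemetry` API in place of the `rivet_metrics` re-exports; the measurement site is an assumption, not code from this commit:

use std::time::Instant;

use opentelemetry::global;
use opentelemetry::metrics::Histogram;

fn main() {
	// Build a histogram the same way METER above does, but directly
	// against the opentelemetry API so the example stands alone.
	let meter = global::meter("rivet-gateway");
	let ping_duration: Histogram<f64> = meter
		.f64_histogram("rivet_gateway_tunnel_ping_duration")
		.with_description("RTT of messages from gateway to pegboard.")
		.build();

	// Hypothetical measurement: time a round trip and record it in seconds.
	let start = Instant::now();
	// ... send a ping over the tunnel and await the matching pong ...
	let rtt_secs = start.elapsed().as_secs_f64();
	ping_duration.record(rtt_secs, &[]);
}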
engine/packages/pegboard-gateway/src/ping_task.rs

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+use anyhow::Result;
+use rivet_runner_protocol as protocol;
+use tokio::sync::watch;
+
+use super::{LifecycleResult, UPDATE_PING_INTERVAL};
+use crate::shared_state::SharedState;
+
+pub async fn task(
+	shared_state: SharedState,
+	request_id: protocol::RequestId,
+	mut ping_abort_rx: watch::Receiver<()>,
+) -> Result<LifecycleResult> {
+	loop {
+		tokio::select! {
+			_ = tokio::time::sleep(UPDATE_PING_INTERVAL) => {}
+			_ = ping_abort_rx.changed() => {
+				return Ok(LifecycleResult::Aborted);
+			}
+		}
+
+		shared_state.send_and_check_ping(request_id).await?;
+	}
+}
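The loop above wakes every `UPDATE_PING_INTERVAL` (3 seconds) and delegates the work to `shared_state.send_and_check_ping`, which is not part of the rendered diff. The sketch below is only a guess at the shape of such a routine: track when the last ping went out and when the last pong came back, fail if the previous ping went unanswered past a deadline, and otherwise send the next ping. Every identifier and constant here (`PingState`, `PING_TIMEOUT`, the 30-second value) is hypothetical and not taken from this commit.

use std::time::{Duration, Instant};

use anyhow::{bail, Result};

// Hypothetical timeout; the real value lives in the shared state
// implementation, which this commit's rendered files do not show.
const PING_TIMEOUT: Duration = Duration::from_secs(30);

/// Minimal stand-in for per-request ping state.
struct PingState {
	last_ping_sent: Option<Instant>,
	last_pong_received: Option<Instant>,
}

impl PingState {
	/// Conceptually what a "send and check ping" routine does: fail if the
	/// previous ping was never answered in time, otherwise send a new ping.
	fn send_and_check_ping(&mut self) -> Result<()> {
		if let Some(sent) = self.last_ping_sent {
			let answered = self
				.last_pong_received
				.map(|pong| pong >= sent)
				.unwrap_or(false);
			if !answered && sent.elapsed() > PING_TIMEOUT {
				bail!("tunnel ping timed out");
			}
		}

		// In the gateway this is where a ping message would be pushed
		// through the tunnel; the sketch only records the timestamp.
		self.last_ping_sent = Some(Instant::now());
		Ok(())
	}
}

fn main() -> Result<()> {
	let mut state = PingState {
		last_ping_sent: None,
		last_pong_received: None,
	};
	state.send_and_check_ping()?;
	// A pong handler elsewhere would set `last_pong_received` and could
	// also record the RTT into the histogram from metrics.rs.
	state.last_pong_received = Some(Instant::now());
	state.send_and_check_ping()?;
	Ok(())
}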
