Skip to content

Commit 633be91

Browse files
committed
fix(tunnel): implement ping pong
1 parent 6521b98 commit 633be91

File tree

16 files changed

+462
-216
lines changed

16 files changed

+462
-216
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/packages/guard-core/src/custom_serve.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@ use bytes::Bytes;
44
use http_body_util::Full;
55
use hyper::{Request, Response};
66
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
7+
use pegboard::tunnel::id::RequestId;
78

89
use crate::WebSocketHandle;
910
use crate::proxy_service::ResponseBody;
1011
use crate::request_context::RequestContext;
1112

12-
use pegboard::tunnel::id::RequestId;
13-
1413
pub enum HibernationResult {
1514
Continue,
1615
Close,

engine/packages/guard/src/routing/api_public.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use bytes::Bytes;
66
use gas::prelude::*;
77
use http_body_util::{BodyExt, Full};
88
use hyper::{Request, Response};
9+
use pegboard::tunnel::id::RequestId;
910
use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput};
1011
use rivet_guard_core::{CustomServeTrait, request_context::RequestContext};
1112
use tower::Service;
@@ -20,6 +21,7 @@ impl CustomServeTrait for ApiPublicService {
2021
&self,
2122
req: Request<Full<Bytes>>,
2223
_request_context: &mut RequestContext,
24+
_request_id: RequestId,
2325
) -> Result<Response<ResponseBody>> {
2426
// Clone the router to get a mutable service
2527
let mut service = self.router.clone();

engine/packages/pegboard-gateway/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ http-body-util.workspace = true
1515
# TODO: Doesn't match workspace version
1616
hyper = "1.6"
1717
hyper-tungstenite.workspace = true
18+
lazy_static.workspace = true
1819
pegboard.workspace = true
1920
rand.workspace = true
2021
rivet-error.workspace = true
2122
rivet-guard-core.workspace = true
23+
rivet-metrics.workspace = true
2224
rivet-runner-protocol.workspace = true
2325
rivet-util.workspace = true
2426
scc.workspace = true
25-
serde.workspace = true
2627
serde_json.workspace = true
28+
serde.workspace = true
2729
thiserror.workspace = true
2830
tokio-tungstenite.workspace = true
2931
tokio.workspace = true

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 69 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,36 @@ use pegboard::tunnel::id::{self as tunnel_id, RequestId};
99
use rand::Rng;
1010
use rivet_error::*;
1111
use rivet_guard_core::{
12+
WebSocketHandle,
1213
custom_serve::{CustomServeTrait, HibernationResult},
13-
errors::{
14-
ServiceUnavailable, WebSocketServiceHibernate, WebSocketServiceTimeout,
15-
WebSocketServiceUnavailable,
16-
},
17-
proxy_service::{is_ws_hibernate, ResponseBody},
14+
errors::{ServiceUnavailable, WebSocketServiceUnavailable},
15+
proxy_service::{ResponseBody, is_ws_hibernate},
1816
request_context::RequestContext,
1917
websocket_handle::WebSocketReceiver,
20-
WebSocketHandle,
2118
};
2219
use rivet_runner_protocol as protocol;
2320
use rivet_util::serde::HashableMap;
2421
use std::{sync::Arc, time::Duration};
2522
use tokio::{
26-
sync::{watch, Mutex},
23+
sync::{Mutex, watch},
2724
task::JoinHandle,
2825
};
2926
use tokio_tungstenite::tungstenite::{
30-
protocol::frame::{coding::CloseCode, CloseFrame},
3127
Message,
28+
protocol::frame::{CloseFrame, coding::CloseCode},
3229
};
3330

3431
use crate::shared_state::{InFlightRequestHandle, SharedState};
3532

33+
mod metrics;
34+
mod ping_task;
3635
pub mod shared_state;
36+
mod tunnel_to_ws_task;
37+
mod ws_to_tunnel_task;
3738

3839
const WEBSOCKET_OPEN_TIMEOUT: Duration = Duration::from_secs(15);
3940
const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(5);
41+
const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3);
4042

4143
#[derive(RivetError, Serialize, Deserialize)]
4244
#[error(
@@ -391,143 +393,47 @@ impl CustomServeTrait for PegboardGateway {
391393

392394
let ws_rx = client_ws.recv();
393395

394-
let (tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch::channel(());
395-
let (ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch::channel(());
396-
397-
// Spawn task to forward messages from tunnel to ws
398-
let shared_state = self.shared_state.clone();
399-
let tunnel_to_ws = tokio::spawn(
400-
async move {
401-
loop {
402-
tokio::select! {
403-
res = msg_rx.recv() => {
404-
if let Some(msg) = res {
405-
match msg {
406-
protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => {
407-
let msg = if ws_msg.binary {
408-
Message::Binary(ws_msg.data.into())
409-
} else {
410-
Message::Text(
411-
String::from_utf8_lossy(&ws_msg.data).into_owned().into(),
412-
)
413-
};
414-
client_ws.send(msg).await?;
415-
}
416-
protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => {
417-
tracing::debug!(
418-
request_id=?tunnel_id::request_id_to_string(&request_id),
419-
ack_index=?ack.index,
420-
"received WebSocketMessageAck from runner"
421-
);
422-
shared_state
423-
.ack_pending_websocket_messages(request_id, ack.index)
424-
.await?;
425-
}
426-
protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => {
427-
tracing::debug!(?close, "server closed websocket");
428-
429-
if can_hibernate && close.hibernate {
430-
return Err(WebSocketServiceHibernate.build());
431-
} else {
432-
// Successful closure
433-
return Ok(LifecycleResult::ServerClose(close));
434-
}
435-
}
436-
_ => {}
437-
}
438-
} else {
439-
tracing::debug!("tunnel sub closed");
440-
return Err(WebSocketServiceHibernate.build());
441-
}
442-
}
443-
_ = stopped_sub.next() => {
444-
tracing::debug!("actor stopped during websocket handler loop");
445-
446-
if can_hibernate {
447-
return Err(WebSocketServiceHibernate.build());
448-
} else {
449-
return Err(WebSocketServiceUnavailable.build());
450-
}
451-
}
452-
_ = drop_rx.changed() => {
453-
tracing::warn!("websocket message timeout");
454-
return Err(WebSocketServiceTimeout.build());
455-
}
456-
_ = tunnel_to_ws_abort_rx.changed() => {
457-
tracing::debug!("task aborted");
458-
return Ok(LifecycleResult::Aborted);
459-
}
460-
}
461-
}
462-
}
463-
.instrument(tracing::info_span!("tunnel_to_ws_task")),
464-
);
465-
466-
// Spawn task to forward messages from ws to tunnel
467-
let shared_state_clone = self.shared_state.clone();
468-
let ws_to_tunnel = tokio::spawn(
469-
async move {
470-
let mut ws_rx = ws_rx.lock().await;
471-
472-
loop {
473-
tokio::select! {
474-
res = ws_rx.try_next() => {
475-
if let Some(msg) = res? {
476-
match msg {
477-
Message::Binary(data) => {
478-
let ws_message =
479-
protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
480-
protocol::ToClientWebSocketMessage {
481-
data: data.into(),
482-
binary: true,
483-
},
484-
);
485-
shared_state_clone
486-
.send_message(request_id, ws_message)
487-
.await?;
488-
}
489-
Message::Text(text) => {
490-
let ws_message =
491-
protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
492-
protocol::ToClientWebSocketMessage {
493-
data: text.as_bytes().to_vec(),
494-
binary: false,
495-
},
496-
);
497-
shared_state_clone
498-
.send_message(request_id, ws_message)
499-
.await?;
500-
}
501-
Message::Close(close) => {
502-
return Ok(LifecycleResult::ClientClose(close));
503-
}
504-
_ => {}
505-
}
506-
} else {
507-
tracing::debug!("websocket stream closed");
508-
return Ok(LifecycleResult::ClientClose(None));
509-
}
510-
}
511-
_ = ws_to_tunnel_abort_rx.changed() => {
512-
tracing::debug!("task aborted");
513-
return Ok(LifecycleResult::Aborted);
514-
}
515-
};
516-
}
517-
}
518-
.instrument(tracing::info_span!("ws_to_tunnel_task")),
519-
);
396+
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
397+
let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
398+
let (ping_abort_tx, ping_abort_rx) = watch::channel(());
399+
400+
let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
401+
self.shared_state.clone(),
402+
client_ws,
403+
request_id,
404+
stopped_sub,
405+
msg_rx,
406+
drop_rx,
407+
can_hibernate,
408+
tunnel_to_ws_abort_rx,
409+
));
410+
let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task(
411+
self.shared_state.clone(),
412+
request_id,
413+
ws_rx,
414+
ws_to_tunnel_abort_rx,
415+
));
416+
let ping = tokio::spawn(ping_task::task(
417+
self.shared_state.clone(),
418+
request_id,
419+
ping_abort_rx,
420+
));
421+
422+
let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone();
423+
let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
424+
let ping_abort_tx2 = ping_abort_tx.clone();
520425

521426
// Wait for both tasks to complete
522-
let (tunnel_to_ws_res, ws_to_tunnel_res) = tokio::join!(
427+
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
523428
async {
524429
let res = tunnel_to_ws.await?;
525430

526431
// Abort other if not aborted
527432
if !matches!(res, Ok(LifecycleResult::Aborted)) {
528433
tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");
529434

530-
drop(ws_to_tunnel_abort_tx);
435+
let _ = ping_abort_tx.send(());
436+
let _ = ws_to_tunnel_abort_tx.send(());
531437
} else {
532438
tracing::debug!(?res, "tunnel to ws task completed");
533439
}
@@ -541,25 +447,42 @@ impl CustomServeTrait for PegboardGateway {
541447
if !matches!(res, Ok(LifecycleResult::Aborted)) {
542448
tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");
543449

544-
drop(tunnel_to_ws_abort_tx);
450+
let _ = ping_abort_tx2.send(());
451+
let _ = tunnel_to_ws_abort_tx.send(());
545452
} else {
546453
tracing::debug!(?res, "ws to tunnel task completed");
547454
}
548455

549456
res
550-
}
457+
},
458+
async {
459+
let res = ping.await?;
460+
461+
// Abort others if not aborted
462+
if !matches!(res, Ok(LifecycleResult::Aborted)) {
463+
tracing::debug!(?res, "ping task completed, aborting others");
464+
465+
let _ = ws_to_tunnel_abort_tx2.send(());
466+
let _ = tunnel_to_ws_abort_tx2.send(());
467+
} else {
468+
tracing::debug!(?res, "ping task completed");
469+
}
470+
471+
res
472+
},
551473
);
552474

553-
// Determine single result from both tasks
554-
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res) {
475+
// Determine single result from all tasks
476+
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
555477
// Prefer error
556-
(Err(err), _) => Err(err),
557-
(_, Err(err)) => Err(err),
478+
(Err(err), _, _) => Err(err),
479+
(_, Err(err), _) => Err(err),
480+
(_, _, Err(err)) => Err(err),
558481
// Prefer non aborted result if both succeed
559-
(Ok(res), Ok(LifecycleResult::Aborted)) => Ok(res),
560-
(Ok(LifecycleResult::Aborted), Ok(res)) => Ok(res),
561-
// Prefer tunnel to ws if both succeed (unlikely case)
562-
(res, _) => res,
482+
(Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
483+
(Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
484+
// Unlikely case
485+
(res, _, _) => res,
563486
};
564487

565488
// Send close frame to runner if not hibernating
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
use rivet_metrics::{
	BUCKETS,
	otel::{global::*, metrics::*},
};

lazy_static::lazy_static! {
	// Shared OpenTelemetry meter for all gateway tunnel metrics in this module.
	static ref METER: Meter = meter("rivet-gateway");

	/// Round-trip time of tunnel ping messages, recorded as a histogram.
	/// Has no expected attributes
	pub static ref TUNNEL_PING_DURATION: Histogram<f64> = METER.f64_histogram("rivet_gateway_tunnel_ping_duration")
		.with_description("RTT of messages from gateway to pegboard.")
		.with_boundaries(BUCKETS.to_vec())
		.build();
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use anyhow::Result;
2+
use rivet_runner_protocol as protocol;
3+
use tokio::sync::watch;
4+
5+
use super::{LifecycleResult, UPDATE_PING_INTERVAL};
6+
use crate::shared_state::SharedState;
7+
8+
pub async fn task(
9+
shared_state: SharedState,
10+
request_id: protocol::RequestId,
11+
mut ping_abort_rx: watch::Receiver<()>,
12+
) -> Result<LifecycleResult> {
13+
loop {
14+
tokio::select! {
15+
_ = tokio::time::sleep(UPDATE_PING_INTERVAL) => {}
16+
_ = ping_abort_rx.changed() => {
17+
return Ok(LifecycleResult::Aborted);
18+
}
19+
}
20+
21+
shared_state.send_and_check_ping(request_id).await?;
22+
}
23+
}

0 commit comments

Comments
 (0)