2 changes: 2 additions & 0 deletions Cargo.lock

(Cargo.lock is a generated file; its diff is not rendered by default.)

3 changes: 1 addition & 2 deletions engine/packages/guard-core/src/custom_serve.rs
@@ -1,16 +1,15 @@
use anyhow::{Result, bail};
use async_trait::async_trait;
use bytes::Bytes;

use http_body_util::Full;
use hyper::{Request, Response};
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
use pegboard::tunnel::id::RequestId;

use crate::WebSocketHandle;
use crate::proxy_service::ResponseBody;
use crate::request_context::RequestContext;

use pegboard::tunnel::id::RequestId;

pub enum HibernationResult {
Continue,
Close,
2 changes: 2 additions & 0 deletions engine/packages/guard/src/routing/api_public.rs
@@ -6,6 +6,7 @@ use bytes::Bytes;
use gas::prelude::*;
use http_body_util::{BodyExt, Full};
use hyper::{Request, Response};
use pegboard::tunnel::id::RequestId;
use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput};
use rivet_guard_core::{CustomServeTrait, request_context::RequestContext};
use tower::Service;
@@ -20,6 +21,7 @@ impl CustomServeTrait for ApiPublicService {
&self,
req: Request<Full<Bytes>>,
_request_context: &mut RequestContext,
_request_id: RequestId,
) -> Result<Response<ResponseBody>> {
// Clone the router to get a mutable service
let mut service = self.router.clone();
4 changes: 3 additions & 1 deletion engine/packages/pegboard-gateway/Cargo.toml
@@ -15,15 +15,17 @@ http-body-util.workspace = true
# TODO: Doesn't match workspace version
hyper = "1.6"
hyper-tungstenite.workspace = true
lazy_static.workspace = true
pegboard.workspace = true
rand.workspace = true
rivet-error.workspace = true
rivet-guard-core.workspace = true
rivet-metrics.workspace = true
rivet-runner-protocol.workspace = true
rivet-util.workspace = true
scc.workspace = true
serde.workspace = true
serde_json.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio-tungstenite.workspace = true
tokio.workspace = true
215 changes: 69 additions & 146 deletions engine/packages/pegboard-gateway/src/lib.rs
@@ -9,34 +9,36 @@ use pegboard::tunnel::id::{self as tunnel_id, RequestId};
use rand::Rng;
use rivet_error::*;
use rivet_guard_core::{
WebSocketHandle,
custom_serve::{CustomServeTrait, HibernationResult},
errors::{
ServiceUnavailable, WebSocketServiceHibernate, WebSocketServiceTimeout,
WebSocketServiceUnavailable,
},
proxy_service::{is_ws_hibernate, ResponseBody},
errors::{ServiceUnavailable, WebSocketServiceUnavailable},
proxy_service::{ResponseBody, is_ws_hibernate},
request_context::RequestContext,
websocket_handle::WebSocketReceiver,
WebSocketHandle,
};
use rivet_runner_protocol as protocol;
use rivet_util::serde::HashableMap;
use std::{sync::Arc, time::Duration};
use tokio::{
sync::{watch, Mutex},
sync::{Mutex, watch},
task::JoinHandle,
};
use tokio_tungstenite::tungstenite::{
protocol::frame::{coding::CloseCode, CloseFrame},
Message,
protocol::frame::{CloseFrame, coding::CloseCode},
};

use crate::shared_state::{InFlightRequestHandle, SharedState};

mod metrics;
mod ping_task;
pub mod shared_state;
mod tunnel_to_ws_task;
mod ws_to_tunnel_task;

const WEBSOCKET_OPEN_TIMEOUT: Duration = Duration::from_secs(15);
const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(5);
const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3);

#[derive(RivetError, Serialize, Deserialize)]
#[error(
@@ -391,143 +393,47 @@ impl CustomServeTrait for PegboardGateway {

let ws_rx = client_ws.recv();

let (tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch::channel(());
let (ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch::channel(());

// Spawn task to forward messages from tunnel to ws
let shared_state = self.shared_state.clone();
let tunnel_to_ws = tokio::spawn(
async move {
loop {
tokio::select! {
res = msg_rx.recv() => {
if let Some(msg) = res {
match msg {
protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => {
let msg = if ws_msg.binary {
Message::Binary(ws_msg.data.into())
} else {
Message::Text(
String::from_utf8_lossy(&ws_msg.data).into_owned().into(),
)
};
client_ws.send(msg).await?;
}
protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => {
tracing::debug!(
request_id=?tunnel_id::request_id_to_string(&request_id),
ack_index=?ack.index,
"received WebSocketMessageAck from runner"
);
shared_state
.ack_pending_websocket_messages(request_id, ack.index)
.await?;
}
protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => {
tracing::debug!(?close, "server closed websocket");

if can_hibernate && close.hibernate {
return Err(WebSocketServiceHibernate.build());
} else {
// Successful closure
return Ok(LifecycleResult::ServerClose(close));
}
}
_ => {}
}
} else {
tracing::debug!("tunnel sub closed");
return Err(WebSocketServiceHibernate.build());
}
}
_ = stopped_sub.next() => {
tracing::debug!("actor stopped during websocket handler loop");

if can_hibernate {
return Err(WebSocketServiceHibernate.build());
} else {
return Err(WebSocketServiceUnavailable.build());
}
}
_ = drop_rx.changed() => {
tracing::warn!("websocket message timeout");
return Err(WebSocketServiceTimeout.build());
}
_ = tunnel_to_ws_abort_rx.changed() => {
tracing::debug!("task aborted");
return Ok(LifecycleResult::Aborted);
}
}
}
}
.instrument(tracing::info_span!("tunnel_to_ws_task")),
);

// Spawn task to forward messages from ws to tunnel
let shared_state_clone = self.shared_state.clone();
let ws_to_tunnel = tokio::spawn(
async move {
let mut ws_rx = ws_rx.lock().await;

loop {
tokio::select! {
res = ws_rx.try_next() => {
if let Some(msg) = res? {
match msg {
Message::Binary(data) => {
let ws_message =
protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
protocol::ToClientWebSocketMessage {
data: data.into(),
binary: true,
},
);
shared_state_clone
.send_message(request_id, ws_message)
.await?;
}
Message::Text(text) => {
let ws_message =
protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage(
protocol::ToClientWebSocketMessage {
data: text.as_bytes().to_vec(),
binary: false,
},
);
shared_state_clone
.send_message(request_id, ws_message)
.await?;
}
Message::Close(close) => {
return Ok(LifecycleResult::ClientClose(close));
}
_ => {}
}
} else {
tracing::debug!("websocket stream closed");
return Ok(LifecycleResult::ClientClose(None));
}
}
_ = ws_to_tunnel_abort_rx.changed() => {
tracing::debug!("task aborted");
return Ok(LifecycleResult::Aborted);
}
};
}
}
.instrument(tracing::info_span!("ws_to_tunnel_task")),
);
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
let (ping_abort_tx, ping_abort_rx) = watch::channel(());

let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
self.shared_state.clone(),
client_ws,
request_id,
stopped_sub,
msg_rx,
drop_rx,
can_hibernate,
tunnel_to_ws_abort_rx,
));
let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task(
self.shared_state.clone(),
request_id,
ws_rx,
ws_to_tunnel_abort_rx,
));
let ping = tokio::spawn(ping_task::task(
self.shared_state.clone(),
request_id,
ping_abort_rx,
));

let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone();
let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
let ping_abort_tx2 = ping_abort_tx.clone();

// Wait for both tasks to complete
let (tunnel_to_ws_res, ws_to_tunnel_res) = tokio::join!(
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
async {
let res = tunnel_to_ws.await?;

// Abort other if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");

drop(ws_to_tunnel_abort_tx);
let _ = ping_abort_tx.send(());
let _ = ws_to_tunnel_abort_tx.send(());
} else {
tracing::debug!(?res, "tunnel to ws task completed");
}
@@ -541,25 +447,42 @@ impl CustomServeTrait for PegboardGateway {
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");

drop(tunnel_to_ws_abort_tx);
let _ = ping_abort_tx2.send(());
let _ = tunnel_to_ws_abort_tx.send(());
} else {
tracing::debug!(?res, "ws to tunnel task completed");
}

res
}
},
async {
let res = ping.await?;

// Abort others if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "ping task completed, aborting others");

let _ = ws_to_tunnel_abort_tx2.send(());
let _ = tunnel_to_ws_abort_tx2.send(());
} else {
tracing::debug!(?res, "ping task completed");
}

res
},
);

// Determine single result from both tasks
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res) {
// Determine single result from all tasks
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
// Prefer error
(Err(err), _) => Err(err),
(_, Err(err)) => Err(err),
(Err(err), _, _) => Err(err),
(_, Err(err), _) => Err(err),
(_, _, Err(err)) => Err(err),
// Prefer non aborted result if both succeed
(Ok(res), Ok(LifecycleResult::Aborted)) => Ok(res),
(Ok(LifecycleResult::Aborted), Ok(res)) => Ok(res),
// Prefer tunnel to ws if both succeed (unlikely case)
(res, _) => res,
(Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
(Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
// Unlikely case
(res, _, _) => res,
};

// Send close frame to runner if not hibernating
14 changes: 14 additions & 0 deletions engine/packages/pegboard-gateway/src/metrics.rs
@@ -0,0 +1,14 @@
use rivet_metrics::{
BUCKETS,
otel::{global::*, metrics::*},
};

lazy_static::lazy_static! {
static ref METER: Meter = meter("rivet-gateway");

/// Has no expected attributes
pub static ref TUNNEL_PING_DURATION: Histogram<f64> = METER.f64_histogram("rivet_gateway_tunnel_ping_duration")
.with_description("RTT of messages from gateway to pegboard.")
.with_boundaries(BUCKETS.to_vec())
.build();
}
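The diff does not show where this histogram is recorded (the shared_state.rs changes are not included). As a minimal sketch, assuming a ping_sent_at timestamp captured when the ping went out over the tunnel, the recording call with the OpenTelemetry histogram API would look roughly like this:

// Hypothetical recording site — not part of the PR; ping_sent_at is an
// assumed std::time::Instant captured when the ping was sent.
let rtt = ping_sent_at.elapsed();
crate::metrics::TUNNEL_PING_DURATION.record(rtt.as_secs_f64(), &[]);

The empty attribute slice matches the "Has no expected attributes" note on the metric.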
23 changes: 23 additions & 0 deletions engine/packages/pegboard-gateway/src/ping_task.rs
@@ -0,0 +1,23 @@
use anyhow::Result;
use rivet_runner_protocol as protocol;
use tokio::sync::watch;

use super::{LifecycleResult, UPDATE_PING_INTERVAL};
use crate::shared_state::SharedState;

pub async fn task(
shared_state: SharedState,
request_id: protocol::RequestId,
mut ping_abort_rx: watch::Receiver<()>,
) -> Result<LifecycleResult> {
loop {
tokio::select! {
_ = tokio::time::sleep(UPDATE_PING_INTERVAL) => {}
_ = ping_abort_rx.changed() => {
return Ok(LifecycleResult::Aborted);
}
}

shared_state.send_and_check_ping(request_id).await?;
}
}
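The shared_state.rs half of send_and_check_ping is not part of the visible diff. As a rough, self-contained sketch of the behavior the ping task implies (every name below is hypothetical, not taken from the PR), the check would presumably fail once the previous ping has gone unacknowledged for longer than TUNNEL_ACK_TIMEOUT; that error propagating out of the ping task is what aborts the other two tasks via the watch channels:

// Hypothetical sketch only; the real logic lives in shared_state.rs,
// which this diff omits.
use std::time::{Duration, Instant};

const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(5);

struct PingState {
    last_ping_at: Instant,
    last_pong_at: Option<Instant>,
}

fn check_ping(state: &PingState) -> anyhow::Result<()> {
    // A pong at or after the last ping means the runner answered in time.
    if matches!(state.last_pong_at, Some(pong) if pong >= state.last_ping_at) {
        return Ok(());
    }
    // Otherwise fail once the outstanding ping has waited too long; the error
    // bubbles out of ping_task::task and tears down the tunnel/ws tasks.
    if state.last_ping_at.elapsed() > TUNNEL_ACK_TIMEOUT {
        anyhow::bail!("tunnel ping timed out");
    }
    Ok(())
}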