diff --git a/Cargo.lock b/Cargo.lock index c8f22040e0..1920300f8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3485,10 +3485,12 @@ dependencies = [ "http-body-util", "hyper 1.6.0", "hyper-tungstenite", + "lazy_static", "pegboard", "rand 0.8.5", "rivet-error", "rivet-guard-core", + "rivet-metrics", "rivet-runner-protocol", "rivet-util", "scc", diff --git a/engine/packages/guard-core/src/custom_serve.rs b/engine/packages/guard-core/src/custom_serve.rs index d7518d06ef..1dd684c723 100644 --- a/engine/packages/guard-core/src/custom_serve.rs +++ b/engine/packages/guard-core/src/custom_serve.rs @@ -4,13 +4,12 @@ use bytes::Bytes; use http_body_util::Full; use hyper::{Request, Response}; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; +use pegboard::tunnel::id::RequestId; use crate::WebSocketHandle; use crate::proxy_service::ResponseBody; use crate::request_context::RequestContext; -use pegboard::tunnel::id::RequestId; - pub enum HibernationResult { Continue, Close, diff --git a/engine/packages/guard/src/routing/api_public.rs b/engine/packages/guard/src/routing/api_public.rs index f6fee1840c..e7ac2342fa 100644 --- a/engine/packages/guard/src/routing/api_public.rs +++ b/engine/packages/guard/src/routing/api_public.rs @@ -6,6 +6,7 @@ use bytes::Bytes; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response}; +use pegboard::tunnel::id::RequestId; use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput}; use rivet_guard_core::{CustomServeTrait, request_context::RequestContext}; use tower::Service; @@ -20,6 +21,7 @@ impl CustomServeTrait for ApiPublicService { &self, req: Request>, _request_context: &mut RequestContext, + _request_id: RequestId, ) -> Result> { // Clone the router to get a mutable service let mut service = self.router.clone(); diff --git a/engine/packages/pegboard-gateway/Cargo.toml b/engine/packages/pegboard-gateway/Cargo.toml index 693bf8de57..31f1d9162c 100644 --- 
a/engine/packages/pegboard-gateway/Cargo.toml +++ b/engine/packages/pegboard-gateway/Cargo.toml @@ -15,15 +15,17 @@ http-body-util.workspace = true # TODO: Doesn't match workspace version hyper = "1.6" hyper-tungstenite.workspace = true +lazy_static.workspace = true pegboard.workspace = true rand.workspace = true rivet-error.workspace = true rivet-guard-core.workspace = true +rivet-metrics.workspace = true rivet-runner-protocol.workspace = true rivet-util.workspace = true scc.workspace = true -serde.workspace = true serde_json.workspace = true +serde.workspace = true thiserror.workspace = true tokio-tungstenite.workspace = true tokio.workspace = true diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 95cd6e8b35..937d144279 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -9,34 +9,36 @@ use pegboard::tunnel::id::{self as tunnel_id, RequestId}; use rand::Rng; use rivet_error::*; use rivet_guard_core::{ + WebSocketHandle, custom_serve::{CustomServeTrait, HibernationResult}, - errors::{ - ServiceUnavailable, WebSocketServiceHibernate, WebSocketServiceTimeout, - WebSocketServiceUnavailable, - }, - proxy_service::{is_ws_hibernate, ResponseBody}, + errors::{ServiceUnavailable, WebSocketServiceUnavailable}, + proxy_service::{ResponseBody, is_ws_hibernate}, request_context::RequestContext, websocket_handle::WebSocketReceiver, - WebSocketHandle, }; use rivet_runner_protocol as protocol; use rivet_util::serde::HashableMap; use std::{sync::Arc, time::Duration}; use tokio::{ - sync::{watch, Mutex}, + sync::{Mutex, watch}, task::JoinHandle, }; use tokio_tungstenite::tungstenite::{ - protocol::frame::{coding::CloseCode, CloseFrame}, Message, + protocol::frame::{CloseFrame, coding::CloseCode}, }; use crate::shared_state::{InFlightRequestHandle, SharedState}; +mod metrics; +mod ping_task; pub mod shared_state; +mod tunnel_to_ws_task; +mod ws_to_tunnel_task; const 
WEBSOCKET_OPEN_TIMEOUT: Duration = Duration::from_secs(15); const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(5); +const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3); #[derive(RivetError, Serialize, Deserialize)] #[error( @@ -391,135 +393,38 @@ impl CustomServeTrait for PegboardGateway { let ws_rx = client_ws.recv(); - let (tunnel_to_ws_abort_tx, mut tunnel_to_ws_abort_rx) = watch::channel(()); - let (ws_to_tunnel_abort_tx, mut ws_to_tunnel_abort_rx) = watch::channel(()); - - // Spawn task to forward messages from tunnel to ws - let shared_state = self.shared_state.clone(); - let tunnel_to_ws = tokio::spawn( - async move { - loop { - tokio::select! { - res = msg_rx.recv() => { - if let Some(msg) = res { - match msg { - protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => { - let msg = if ws_msg.binary { - Message::Binary(ws_msg.data.into()) - } else { - Message::Text( - String::from_utf8_lossy(&ws_msg.data).into_owned().into(), - ) - }; - client_ws.send(msg).await?; - } - protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { - tracing::debug!( - request_id=?tunnel_id::request_id_to_string(&request_id), - ack_index=?ack.index, - "received WebSocketMessageAck from runner" - ); - shared_state - .ack_pending_websocket_messages(request_id, ack.index) - .await?; - } - protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { - tracing::debug!(?close, "server closed websocket"); - - if can_hibernate && close.hibernate { - return Err(WebSocketServiceHibernate.build()); - } else { - // Successful closure - return Ok(LifecycleResult::ServerClose(close)); - } - } - _ => {} - } - } else { - tracing::debug!("tunnel sub closed"); - return Err(WebSocketServiceHibernate.build()); - } - } - _ = stopped_sub.next() => { - tracing::debug!("actor stopped during websocket handler loop"); - - if can_hibernate { - return Err(WebSocketServiceHibernate.build()); - } else { - return 
Err(WebSocketServiceUnavailable.build()); - } - } - _ = drop_rx.changed() => { - tracing::warn!("websocket message timeout"); - return Err(WebSocketServiceTimeout.build()); - } - _ = tunnel_to_ws_abort_rx.changed() => { - tracing::debug!("task aborted"); - return Ok(LifecycleResult::Aborted); - } - } - } - } - .instrument(tracing::info_span!("tunnel_to_ws_task")), - ); - - // Spawn task to forward messages from ws to tunnel - let shared_state_clone = self.shared_state.clone(); - let ws_to_tunnel = tokio::spawn( - async move { - let mut ws_rx = ws_rx.lock().await; - - loop { - tokio::select! { - res = ws_rx.try_next() => { - if let Some(msg) = res? { - match msg { - Message::Binary(data) => { - let ws_message = - protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage( - protocol::ToClientWebSocketMessage { - data: data.into(), - binary: true, - }, - ); - shared_state_clone - .send_message(request_id, ws_message) - .await?; - } - Message::Text(text) => { - let ws_message = - protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage( - protocol::ToClientWebSocketMessage { - data: text.as_bytes().to_vec(), - binary: false, - }, - ); - shared_state_clone - .send_message(request_id, ws_message) - .await?; - } - Message::Close(close) => { - return Ok(LifecycleResult::ClientClose(close)); - } - _ => {} - } - } else { - tracing::debug!("websocket stream closed"); - return Ok(LifecycleResult::ClientClose(None)); - } - } - _ = ws_to_tunnel_abort_rx.changed() => { - tracing::debug!("task aborted"); - return Ok(LifecycleResult::Aborted); - } - }; - } - } - .instrument(tracing::info_span!("ws_to_tunnel_task")), - ); + let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(()); + let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(()); + let (ping_abort_tx, ping_abort_rx) = watch::channel(()); + + let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task( + self.shared_state.clone(), + client_ws, + request_id, + stopped_sub, + msg_rx, + 
drop_rx, + can_hibernate, + tunnel_to_ws_abort_rx, + )); + let ws_to_tunnel = tokio::spawn(ws_to_tunnel_task::task( + self.shared_state.clone(), + request_id, + ws_rx, + ws_to_tunnel_abort_rx, + )); + let ping = tokio::spawn(ping_task::task( + self.shared_state.clone(), + request_id, + ping_abort_rx, + )); + + let tunnel_to_ws_abort_tx2 = tunnel_to_ws_abort_tx.clone(); + let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone(); + let ping_abort_tx2 = ping_abort_tx.clone(); // Wait for both tasks to complete - let (tunnel_to_ws_res, ws_to_tunnel_res) = tokio::join!( + let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!( async { let res = tunnel_to_ws.await?; @@ -527,7 +432,8 @@ impl CustomServeTrait for PegboardGateway { if !matches!(res, Ok(LifecycleResult::Aborted)) { tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart"); - drop(ws_to_tunnel_abort_tx); + let _ = ping_abort_tx.send(()); + let _ = ws_to_tunnel_abort_tx.send(()); } else { tracing::debug!(?res, "tunnel to ws task completed"); } @@ -541,25 +447,42 @@ impl CustomServeTrait for PegboardGateway { if !matches!(res, Ok(LifecycleResult::Aborted)) { tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart"); - drop(tunnel_to_ws_abort_tx); + let _ = ping_abort_tx2.send(()); + let _ = tunnel_to_ws_abort_tx.send(()); } else { tracing::debug!(?res, "ws to tunnel task completed"); } res - } + }, + async { + let res = ping.await?; + + // Abort others if not aborted + if !matches!(res, Ok(LifecycleResult::Aborted)) { + tracing::debug!(?res, "ping task completed, aborting others"); + + let _ = ws_to_tunnel_abort_tx2.send(()); + let _ = tunnel_to_ws_abort_tx2.send(()); + } else { + tracing::debug!(?res, "ping task completed"); + } + + res + }, ); - // Determine single result from both tasks - let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res) { + // Determine single result from all tasks + let mut lifecycle_res = match (tunnel_to_ws_res, 
ws_to_tunnel_res, ping_res) { // Prefer error - (Err(err), _) => Err(err), - (_, Err(err)) => Err(err), + (Err(err), _, _) => Err(err), + (_, Err(err), _) => Err(err), + (_, _, Err(err)) => Err(err), // Prefer non aborted result if both succeed - (Ok(res), Ok(LifecycleResult::Aborted)) => Ok(res), - (Ok(LifecycleResult::Aborted), Ok(res)) => Ok(res), - // Prefer tunnel to ws if both succeed (unlikely case) - (res, _) => res, + (Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res), + (Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res), + // Unlikely case + (res, _, _) => res, }; // Send close frame to runner if not hibernating diff --git a/engine/packages/pegboard-gateway/src/metrics.rs b/engine/packages/pegboard-gateway/src/metrics.rs new file mode 100644 index 0000000000..7fce45acba --- /dev/null +++ b/engine/packages/pegboard-gateway/src/metrics.rs @@ -0,0 +1,14 @@ +use rivet_metrics::{ + BUCKETS, + otel::{global::*, metrics::*}, +}; + +lazy_static::lazy_static! { + static ref METER: Meter = meter("rivet-gateway"); + + /// Has no expected attributes + pub static ref TUNNEL_PING_DURATION: Histogram = METER.f64_histogram("rivet_gateway_tunnel_ping_duration") + .with_description("RTT of messages from gateway to pegboard.") + .with_boundaries(BUCKETS.to_vec()) + .build(); +} diff --git a/engine/packages/pegboard-gateway/src/ping_task.rs b/engine/packages/pegboard-gateway/src/ping_task.rs new file mode 100644 index 0000000000..01cf19618a --- /dev/null +++ b/engine/packages/pegboard-gateway/src/ping_task.rs @@ -0,0 +1,23 @@ +use anyhow::Result; +use rivet_runner_protocol as protocol; +use tokio::sync::watch; + +use super::{LifecycleResult, UPDATE_PING_INTERVAL}; +use crate::shared_state::SharedState; + +pub async fn task( + shared_state: SharedState, + request_id: protocol::RequestId, + mut ping_abort_rx: watch::Receiver<()>, +) -> Result { + loop { + tokio::select! 
{ + _ = tokio::time::sleep(UPDATE_PING_INTERVAL) => {} + _ = ping_abort_rx.changed() => { + return Ok(LifecycleResult::Aborted); + } + } + + shared_state.send_and_check_ping(request_id).await?; + } +} diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs index fc47b50548..a277916663 100644 --- a/engine/packages/pegboard-gateway/src/shared_state.rs +++ b/engine/packages/pegboard-gateway/src/shared_state.rs @@ -1,6 +1,7 @@ use anyhow::Result; use gas::prelude::*; use pegboard::tunnel::id::{self as tunnel_id, GatewayId, RequestId}; +use rivet_guard_core::errors::WebSocketServiceTimeout; use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; use scc::{HashMap, hash_map::Entry}; use std::{ @@ -12,9 +13,10 @@ use tokio::sync::{mpsc, watch}; use universalpubsub::{NextOutput, PubSub, PublishOpts, Subscriber}; use vbare::OwnedVersionedData; -use crate::WebsocketPendingLimitReached; +use crate::{WebsocketPendingLimitReached, metrics}; const GC_INTERVAL: Duration = Duration::from_secs(15); +const TUNNEL_PING_TIMEOUT: i64 = util::duration::seconds(30); const HWS_MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(30); const HWS_MAX_PENDING_MSGS_SIZE_PER_REQ: u64 = util::size::mebibytes(1); @@ -40,6 +42,7 @@ struct InFlightRequest { message_index: tunnel_id::MessageIndex, hibernation_state: Option, stopping: bool, + last_pong: i64, } struct HibernationState { @@ -81,6 +84,7 @@ impl SharedState { self.gateway_id } + #[tracing::instrument(skip_all)] pub async fn start(&self) -> Result<()> { let sub = self.ups.subscribe(&self.receiver_subject).await?; @@ -93,6 +97,7 @@ impl SharedState { Ok(()) } + #[tracing::instrument(skip_all, fields(%receiver_subject, request_id=?tunnel_id::request_id_to_string(&request_id)))] pub async fn start_in_flight_request( &self, receiver_subject: String, @@ -111,6 +116,7 @@ impl SharedState { message_index: 0, hibernation_state: None, stopping: false, + last_pong: 
util::timestamp::now(), }); } Entry::Occupied(mut entry) => { @@ -130,6 +136,7 @@ impl SharedState { InFlightRequestHandle { msg_rx, drop_rx } } + #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))] pub async fn send_message( &self, request_id: RequestId, @@ -200,6 +207,43 @@ impl SharedState { Ok(()) } + #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))] + pub async fn send_and_check_ping(&self, request_id: RequestId) -> Result<()> { + let req = self + .in_flight_requests + .get_async(&request_id) + .await + .context("request not in flight")?; + + let now = util::timestamp::now(); + + // Verify ping timeout + if now.saturating_sub(req.last_pong) > TUNNEL_PING_TIMEOUT { + tracing::warn!("tunnel timeout"); + return Err(WebSocketServiceTimeout.build()); + } + + // Send message + let message = protocol::ToRunner::ToRunnerPing(protocol::ToRunnerPing { + gateway_id: self.gateway_id, + request_id, + ts: now, + }); + let message_serialized = versioned::ToRunner::wrap_latest(message) + .serialize_with_embedded_version(PROTOCOL_VERSION)?; + + self.ups + .publish( + &req.receiver_subject, + &message_serialized, + PublishOpts::one(), + ) + .await?; + + Ok(()) + } + + #[tracing::instrument(skip_all)] async fn receiver(&self, mut sub: Subscriber) { while let Ok(NextOutput::Message(msg)) = sub.next().await { tracing::trace!( @@ -208,8 +252,22 @@ impl SharedState { ); match versioned::ToGateway::deserialize_with_embedded_version(&msg.payload) { - Ok(protocol::ToGateway::ToGatewayKeepAlive) => { - // No-op + Ok(protocol::ToGateway::ToGatewayPong(pong)) => { + let Some(mut in_flight) = + self.in_flight_requests.get_async(&pong.request_id).await + else { + tracing::debug!( + request_id=?tunnel_id::request_id_to_string(&pong.request_id), + "in flight has already been disconnected, dropping ping" + ); + continue; + }; + + let now = util::timestamp::now(); + in_flight.last_pong = now; + + let 
rtt = now.saturating_sub(pong.ts); + metrics::TUNNEL_PING_DURATION.record(rtt as f64 * 0.001, &[]); } Ok(protocol::ToGateway::ToServerTunnelMessage(msg)) => { // Parse message ID to extract components @@ -249,6 +307,7 @@ impl SharedState { } } + #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id), %enable))] pub async fn toggle_hibernation(&self, request_id: RequestId, enable: bool) -> Result<()> { let mut req = self .in_flight_requests @@ -271,6 +330,7 @@ impl SharedState { Ok(()) } + #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))] pub async fn resend_pending_websocket_messages(&self, request_id: RequestId) -> Result<()> { let Some(mut req) = self.in_flight_requests.get_async(&request_id).await else { bail!("request not in flight"); @@ -293,6 +353,7 @@ impl SharedState { Ok(()) } + #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id)))] pub async fn has_pending_websocket_messages(&self, request_id: RequestId) -> Result { let Some(req) = self.in_flight_requests.get_async(&request_id).await else { bail!("request not in flight"); @@ -305,6 +366,7 @@ impl SharedState { } } + #[tracing::instrument(skip_all, fields(request_id=?tunnel_id::request_id_to_string(&request_id), %ack_index))] pub async fn ack_pending_websocket_messages( &self, request_id: RequestId, @@ -324,22 +386,22 @@ impl SharedState { // Retain messages with index > ack_index (messages that haven't been acknowledged yet) let len_before = hs.pending_ws_msgs.len(); - hs.pending_ws_msgs.retain(|msg| { - wrapping_gt(msg.message_index, ack_index) - }); + hs.pending_ws_msgs + .retain(|msg| wrapping_gt(msg.message_index, ack_index)); let len_after = hs.pending_ws_msgs.len(); tracing::debug!( request_id=?tunnel_id::request_id_to_string(&request_id), ack_index, - removed_count=len_before - len_after, - remaining_count=len_after, + removed_count = len_before - len_after, + 
remaining_count = len_after, "acked pending websocket messages" ); Ok(()) } + #[tracing::instrument(skip_all)] async fn gc(&self) { let mut interval = tokio::time::interval(GC_INTERVAL); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -371,13 +433,19 @@ impl SharedState { /// **Phase 2** /// /// 2a. Remove all requests where it was flagged as stopping and `drop_rx` has been dropped + #[tracing::instrument(skip_all)] async fn gc_in_flight_requests(&self) { #[derive(Debug)] enum MsgGcReason { /// Gateway channel is closed and there are no pending messages GatewayClosed, /// WebSocket pending messages (ToServerWebSocketMessageAck) - WebSocketMessageNotAcked { message_index: u16 }, + WebSocketMessageNotAcked { + #[allow(dead_code)] + first_msg_index: u16, + #[allow(dead_code)] + last_msg_index: u16, + }, } let now = Instant::now(); @@ -411,7 +479,10 @@ impl SharedState { if now.duration_since(earliest_pending_ws_msg.send_instant) > HWS_MESSAGE_ACK_TIMEOUT { - break 'reason Some(MsgGcReason::WebSocketMessageNotAcked{message_index: earliest_pending_ws_msg.message_index}); + break 'reason Some(MsgGcReason::WebSocketMessageNotAcked { + first_msg_index: earliest_pending_ws_msg.message_index, + last_msg_index: req.message_index + }); } } diff --git a/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs new file mode 100644 index 0000000000..8926e4d9f6 --- /dev/null +++ b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs @@ -0,0 +1,86 @@ +use anyhow::Result; +use gas::prelude::*; +use pegboard::tunnel::id as tunnel_id; +use rivet_guard_core::{ + WebSocketHandle, + errors::{WebSocketServiceHibernate, WebSocketServiceTimeout, WebSocketServiceUnavailable}, +}; +use rivet_runner_protocol as protocol; +use tokio::sync::{mpsc, watch}; +use tokio_tungstenite::tungstenite::Message; + +use super::LifecycleResult; +use crate::shared_state::SharedState; + +pub async fn task( + 
shared_state: SharedState, + client_ws: WebSocketHandle, + request_id: protocol::RequestId, + mut stopped_sub: message::SubscriptionHandle, + mut msg_rx: mpsc::Receiver, + mut drop_rx: watch::Receiver<()>, + can_hibernate: bool, + mut tunnel_to_ws_abort_rx: watch::Receiver<()>, +) -> Result { + loop { + tokio::select! { + res = msg_rx.recv() => { + if let Some(msg) = res { + match msg { + protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(ws_msg) => { + let msg = if ws_msg.binary { + Message::Binary(ws_msg.data.into()) + } else { + Message::Text( + String::from_utf8_lossy(&ws_msg.data).into_owned().into(), + ) + }; + client_ws.send(msg).await?; + } + protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { + tracing::debug!( + request_id=?tunnel_id::request_id_to_string(&request_id), + ack_index=?ack.index, + "received WebSocketMessageAck from runner" + ); + shared_state + .ack_pending_websocket_messages(request_id, ack.index) + .await?; + } + protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { + tracing::debug!(?close, "server closed websocket"); + + if can_hibernate && close.hibernate { + return Err(WebSocketServiceHibernate.build()); + } else { + // Successful closure + return Ok(LifecycleResult::ServerClose(close)); + } + } + _ => {} + } + } else { + tracing::debug!("tunnel sub closed"); + return Err(WebSocketServiceHibernate.build()); + } + } + _ = stopped_sub.next() => { + tracing::debug!("actor stopped during websocket handler loop"); + + if can_hibernate { + return Err(WebSocketServiceHibernate.build()); + } else { + return Err(WebSocketServiceUnavailable.build()); + } + } + _ = drop_rx.changed() => { + tracing::warn!("websocket message timeout"); + return Err(WebSocketServiceTimeout.build()); + } + _ = tunnel_to_ws_abort_rx.changed() => { + tracing::debug!("task aborted"); + return Ok(LifecycleResult::Aborted); + } + } + } +} diff --git a/engine/packages/pegboard-gateway/src/ws_to_tunnel_task.rs 
b/engine/packages/pegboard-gateway/src/ws_to_tunnel_task.rs new file mode 100644 index 0000000000..f28ce389e6 --- /dev/null +++ b/engine/packages/pegboard-gateway/src/ws_to_tunnel_task.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use futures_util::TryStreamExt; +use rivet_guard_core::websocket_handle::WebSocketReceiver; +use rivet_runner_protocol as protocol; +use std::sync::Arc; +use tokio::sync::{Mutex, watch}; +use tokio_tungstenite::tungstenite::Message; + +use super::LifecycleResult; +use crate::shared_state::SharedState; + +pub async fn task( + shared_state: SharedState, + request_id: protocol::RequestId, + ws_rx: Arc>, + mut ws_to_tunnel_abort_rx: watch::Receiver<()>, +) -> Result { + let mut ws_rx = ws_rx.lock().await; + + loop { + tokio::select! { + res = ws_rx.try_next() => { + if let Some(msg) = res? { + match msg { + Message::Binary(data) => { + let ws_message = + protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage( + protocol::ToClientWebSocketMessage { + data: data.into(), + binary: true, + }, + ); + shared_state + .send_message(request_id, ws_message) + .await?; + } + Message::Text(text) => { + let ws_message = + protocol::ToClientTunnelMessageKind::ToClientWebSocketMessage( + protocol::ToClientWebSocketMessage { + data: text.as_bytes().to_vec(), + binary: false, + }, + ); + shared_state + .send_message(request_id, ws_message) + .await?; + } + Message::Close(close) => { + return Ok(LifecycleResult::ClientClose(close)); + } + _ => {} + } + } else { + tracing::debug!("websocket stream closed"); + return Ok(LifecycleResult::ClientClose(None)); + } + } + _ = ws_to_tunnel_abort_rx.changed() => { + tracing::debug!("task aborted"); + return Ok(LifecycleResult::Aborted); + } + }; + } +} diff --git a/engine/packages/pegboard-runner/src/lib.rs b/engine/packages/pegboard-runner/src/lib.rs index 895e9aee19..0632f217de 100644 --- a/engine/packages/pegboard-runner/src/lib.rs +++ b/engine/packages/pegboard-runner/src/lib.rs @@ -212,7 +212,7 @@ impl 
CustomServeTrait for PegboardRunnerWsCustomServe { } ); - // Determine single result from both tasks + // Determine single result from all tasks let lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) { // Prefer error (Err(err), _, _) => Err(err), diff --git a/engine/packages/pegboard-runner/src/ping_task.rs b/engine/packages/pegboard-runner/src/ping_task.rs index 02cfca5a01..68d48d07aa 100644 --- a/engine/packages/pegboard-runner/src/ping_task.rs +++ b/engine/packages/pegboard-runner/src/ping_task.rs @@ -20,41 +20,47 @@ pub async fn task( } } - let Some(wf) = ctx - .workflow::(conn.workflow_id) - .get() - .await? - else { - tracing::error!(?conn.runner_id, "workflow does not exist"); - continue; - }; - - // Check workflow is not dead - if !wf.has_wake_condition { - continue; - } + update_runner_ping(&ctx, &conn).await?; + } +} - // Update ping - let rtt = conn.last_rtt.load(Ordering::Relaxed); - let res = ctx - .op(pegboard::ops::runner::update_alloc_idx::Input { - runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { - runner_id: conn.runner_id, - action: Action::UpdatePing { rtt }, - }], - }) - .await?; - - // If runner became eligible again, then pull any pending actors - for notif in res.notifications { - if let RunnerEligibility::ReEligible = notif.eligibility { - tracing::debug!(runner_id=?notif.runner_id, "runner has become eligible again"); - - ctx.signal(pegboard::workflows::runner::CheckQueue {}) - .to_workflow_id(notif.workflow_id) - .send() - .await?; - } +async fn update_runner_ping(ctx: &StandaloneCtx, conn: &Conn) -> Result<()> { + let Some(wf) = ctx + .workflow::(conn.workflow_id) + .get() + .await? 
+ else { + tracing::error!(?conn.runner_id, "workflow does not exist"); + return Ok(()); + }; + + // Check workflow is not dead + if !wf.has_wake_condition { + return Ok(()); + } + + // Update ping + let rtt = conn.last_rtt.load(Ordering::Relaxed); + let res = ctx + .op(pegboard::ops::runner::update_alloc_idx::Input { + runners: vec![pegboard::ops::runner::update_alloc_idx::Runner { + runner_id: conn.runner_id, + action: Action::UpdatePing { rtt }, + }], + }) + .await?; + + // If runner became eligible again, then pull any pending actors + for notif in res.notifications { + if let RunnerEligibility::ReEligible = notif.eligibility { + tracing::debug!(runner_id=?notif.runner_id, "runner has become eligible again"); + + ctx.signal(pegboard::workflows::runner::CheckQueue {}) + .to_workflow_id(notif.workflow_id) + .send() + .await?; } } + + Ok(()) } diff --git a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs index 2a029f54bf..857bc40a71 100644 --- a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs @@ -1,10 +1,11 @@ use anyhow::Result; use gas::prelude::*; use hyper_tungstenite::tungstenite::Message as WsMessage; +use pegboard::pubsub_subjects::GatewayReceiverSubject; use rivet_runner_protocol::{self as protocol, versioned}; use std::sync::Arc; use tokio::sync::watch; -use universalpubsub::{NextOutput, Subscriber}; +use universalpubsub::{NextOutput, PublishOpts, Subscriber}; use vbare::OwnedVersionedData; use crate::{LifecycleResult, conn::Conn, errors}; @@ -53,8 +54,29 @@ pub async fn task( // Convert to ToClient types let to_client_msg = match msg { - protocol::ToRunner::ToRunnerKeepAlive(_) => { - // TODO: + protocol::ToRunner::ToRunnerPing(ping) => { + // Publish pong to UPS + let gateway_reply_to = GatewayReceiverSubject::new(ping.gateway_id).to_string(); + let msg_serialized = versioned::ToGateway::wrap_latest( + 
protocol::ToGateway::ToGatewayPong(protocol::ToGatewayPong { + request_id: ping.request_id, + ts: ping.ts, + }), + ) + .serialize_with_embedded_version(protocol::PROTOCOL_VERSION) + .context("failed to serialize pong message for gateway")?; + ctx.ups() + .context("failed to get UPS instance for tunnel message")? + .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) + .await + .with_context(|| { + format!( + "failed to publish tunnel message to gateway reply topic: {}", + gateway_reply_to + ) + })?; + + // Not sent to client continue; } protocol::ToRunner::ToClientInit(x) => protocol::ToClient::ToClientInit(x), diff --git a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs index 1758695d6c..f37efc3474 100644 --- a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs @@ -91,10 +91,11 @@ async fn handle_message( match msg { protocol::ToServer::ToServerPing(ping) => { let now = util::timestamp::now(); - let rtt = if ping.ts <= now { - // Calculate RTT, clamping to u32::MAX if too large - let rtt_ms = now.saturating_sub(ping.ts); - rtt_ms.min(u32::MAX as i64) as u32 + + let delta = if ping.ts <= now { + // Calculate delta, clamping to u32::MAX if too large + let delta_ms = now.saturating_sub(ping.ts); + delta_ms.min(u32::MAX as i64) as u32 } else { // If ping timestamp is in the future (clock skew), default to 0 tracing::warn!( @@ -105,6 +106,9 @@ async fn handle_message( 0 }; + // Assuming symmetric delta + let rtt = delta * 2; + conn.last_rtt.store(rtt, Ordering::Relaxed); } // Process KV request diff --git a/engine/packages/pegboard-serverless/src/lib.rs b/engine/packages/pegboard-serverless/src/lib.rs index 19d9c9386c..5591254f2f 100644 --- a/engine/packages/pegboard-serverless/src/lib.rs +++ b/engine/packages/pegboard-serverless/src/lib.rs @@ -511,8 +511,8 @@ async fn publish_to_client_stop(ctx: &StandaloneCtx, runner_id: 
Id) -> Result<() let receiver_subject = pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string(); - let message_serialized = rivet_runner_protocol::versioned::ToClient::wrap_latest( - rivet_runner_protocol::ToClient::ToClientClose, + let message_serialized = rivet_runner_protocol::versioned::ToRunner::wrap_latest( + rivet_runner_protocol::ToRunner::ToClientClose, ) .serialize_with_embedded_version(rivet_runner_protocol::PROTOCOL_VERSION)?; diff --git a/engine/packages/pegboard/src/tunnel/id.rs b/engine/packages/pegboard/src/tunnel/id.rs index 6bad4de4ed..67a1582e39 100644 --- a/engine/packages/pegboard/src/tunnel/id.rs +++ b/engine/packages/pegboard/src/tunnel/id.rs @@ -6,7 +6,7 @@ use rivet_runner_protocol as protocol; pub type GatewayId = [u8; 4]; pub type RequestId = [u8; 4]; pub type MessageIndex = u16; -pub type MessageId = [u8; 12]; +pub type MessageId = [u8; 10]; /// Generate a new 4-byte gateway ID from a random u32 pub fn generate_gateway_id() -> GatewayId { @@ -26,15 +26,15 @@ pub fn build_message_id( }; // Serialize directly to a fixed-size buffer on the stack - let mut message_id = [0u8; 12]; + let mut message_id = [0u8; 10]; let mut cursor = std::io::Cursor::new(&mut message_id[..]); serde_bare::to_writer(&mut cursor, &parts).context("failed to serialize message id parts")?; - // Verify we wrote exactly 12 bytes + // Verify we wrote exactly 10 bytes let written = cursor.position() as usize; ensure!( - written == 12, - "message id serialization produced wrong size: expected 12 bytes, got {}", + written == 10, + "message id serialization produced wrong size: expected 10 bytes, got {}", written ); diff --git a/engine/packages/pegboard/src/workflows/runner.rs b/engine/packages/pegboard/src/workflows/runner.rs index a0a4eb0276..b87ef3acc5 100644 --- a/engine/packages/pegboard/src/workflows/runner.rs +++ b/engine/packages/pegboard/src/workflows/runner.rs @@ -2,7 +2,7 @@ use futures_util::{FutureExt, StreamExt, TryStreamExt}; use 
gas::prelude::*; use rivet_data::converted::{ActorNameKeyData, MetadataKeyData, RunnerByKeyKeyData}; use rivet_metrics::KeyValue; -use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use rivet_runner_protocol::{self as protocol, versioned, PROTOCOL_VERSION}; use universaldb::{ options::{ConflictRangeType, StreamingMode}, utils::{FormalChunkedKey, IsolationLevel::*}, @@ -97,7 +97,7 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> // Send init packet ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - message: protocol::ToClient::ToClientInit(protocol::ToClientInit { + message: protocol::ToRunner::ToClientInit(protocol::ToClientInit { runner_id: input.runner_id.to_string(), last_event_idx: init_data.last_event_idx, metadata: protocol::ProtocolMetadata { @@ -111,7 +111,7 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> if !init_data.missed_commands.is_empty() { ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - message: protocol::ToClient::ToClientCommands( + message: protocol::ToRunner::ToClientCommands( init_data.missed_commands, ), }) @@ -203,7 +203,7 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - message: protocol::ToClient::ToClientAckEvents( + message: protocol::ToRunner::ToClientAckEvents( protocol::ToClientAckEvents { last_event_idx: state.last_event_ack_idx, }, @@ -282,7 +282,7 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> // Forward ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - message: protocol::ToClient::ToClientCommands(vec![ + message: protocol::ToRunner::ToClientCommands(vec![ protocol::CommandWrapper { index, inner: command.inner, @@ -376,7 +376,7 @@ pub async fn pegboard_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> // Close websocket 
connection (its unlikely to be open) ctx.activity(SendMessageToRunnerInput { runner_id: input.runner_id, - message: protocol::ToClient::ToClientClose, + message: protocol::ToRunner::ToClientClose, }) .await?; @@ -1137,7 +1137,7 @@ pub(crate) async fn allocate_pending_actors( #[derive(Debug, Serialize, Deserialize, Hash)] struct SendMessageToRunnerInput { runner_id: Id, - message: protocol::ToClient, + message: protocol::ToRunner, } #[activity(SendMessageToRunner)] @@ -1145,7 +1145,7 @@ async fn send_message_to_runner(ctx: &ActivityCtx, input: &SendMessageToRunnerIn let receiver_subject = crate::pubsub_subjects::RunnerReceiverSubject::new(input.runner_id).to_string(); - let message_serialized = versioned::ToClient::wrap_latest(input.message.clone()) + let message_serialized = versioned::ToRunner::wrap_latest(input.message.clone()) .serialize_with_embedded_version(PROTOCOL_VERSION)?; ctx.ups()? diff --git a/engine/sdks/rust/runner-protocol/src/versioned.rs b/engine/sdks/rust/runner-protocol/src/versioned.rs index aa4dfae195..a832939f2f 100644 --- a/engine/sdks/rust/runner-protocol/src/versioned.rs +++ b/engine/sdks/rust/runner-protocol/src/versioned.rs @@ -175,10 +175,10 @@ impl ToClient { } v2::ToClient::ToClientTunnelMessage(msg) => { // Extract v3 message_id from v2's message_id - // v3: gateway_id (4) + request_id (4) + message_index (4) = 12 bytes - // v2.message_id contains: entire v3 message_id (12 bytes) + padding (4 bytes) - let mut message_id = [0u8; 12]; - message_id.copy_from_slice(&msg.message_id[..12]); + // v3: gateway_id (4) + request_id (4) + message_index (2) = 10 bytes + // v2.message_id contains: entire v3 message_id (10 bytes) + padding (6 bytes) + let mut message_id = [0u8; 10]; + message_id.copy_from_slice(&msg.message_id[..10]); v3::ToClient::ToClientTunnelMessage(v3::ToClientTunnelMessage { message_id, @@ -247,13 +247,13 @@ impl ToClient { } v3::ToClient::ToClientTunnelMessage(msg) => { // Split v3 message_id into v2's request_id and 
message_id - // v3: gateway_id (4) + request_id (4) + message_index (4) = 12 bytes + // v3: gateway_id (4) + request_id (4) + message_index (2) = 10 bytes // v2.request_id = gateway_id (4) + request_id (4) + padding (8 zeros) - // v2.message_id = entire v3 message_id (12 bytes) + padding (4 zeros) + // v2.message_id = entire v3 message_id (10 bytes) + padding (6 zeros) let mut request_id = [0u8; 16]; let mut message_id = [0u8; 16]; request_id[..8].copy_from_slice(&msg.message_id[..8]); // gateway_id + request_id - message_id[..12].copy_from_slice(&msg.message_id); // entire v3 message_id + message_id[..10].copy_from_slice(&msg.message_id); // entire v3 message_id v2::ToClient::ToClientTunnelMessage(v2::ToClientTunnelMessage { request_id, @@ -503,10 +503,10 @@ impl ToServer { } v2::ToServer::ToServerTunnelMessage(msg) => { // Extract v3 message_id from v2's message_id - // v3: gateway_id (4) + request_id (4) + message_index (4) = 12 bytes - // v2.message_id contains: entire v3 message_id (12 bytes) + padding (4 bytes) - let mut message_id = [0u8; 12]; - message_id.copy_from_slice(&msg.message_id[..12]); + // v3: gateway_id (4) + request_id (4) + message_index (2) = 10 bytes + // v2.message_id contains: entire v3 message_id (10 bytes) + padding (6 bytes) + let mut message_id = [0u8; 10]; + message_id.copy_from_slice(&msg.message_id[..10]); v3::ToServer::ToServerTunnelMessage(v3::ToServerTunnelMessage { message_id, @@ -572,13 +572,13 @@ impl ToServer { } v3::ToServer::ToServerTunnelMessage(msg) => { // Split v3 message_id into v2's request_id and message_id - // v3: gateway_id (4) + request_id (4) + message_index (4) = 12 bytes + // v3: gateway_id (4) + request_id (4) + message_index (2) = 10 bytes // v2.request_id = gateway_id (4) + request_id (4) + padding (8 zeros) - // v2.message_id = entire v3 message_id (12 bytes) + padding (4 zeros) + // v2.message_id = entire v3 message_id (10 bytes) + padding (6 zeros) let mut request_id = [0u8; 16]; let mut message_id =
[0u8; 16]; request_id[..8].copy_from_slice(&msg.message_id[..8]); // gateway_id + request_id - message_id[..12].copy_from_slice(&msg.message_id); // entire v3 message_id + message_id[..10].copy_from_slice(&msg.message_id); // entire v3 message_id v2::ToServer::ToServerTunnelMessage(v2::ToServerTunnelMessage { request_id, @@ -1262,7 +1262,7 @@ fn convert_to_client_tunnel_message_kind_v2_to_v3( fn convert_to_client_tunnel_message_kind_v3_to_v2( kind: v3::ToClientTunnelMessageKind, - message_id: &[u8; 12], + message_id: &[u8; 10], ) -> Result { Ok(match kind { v3::ToClientTunnelMessageKind::ToClientRequestStart(req) => { diff --git a/engine/sdks/schemas/runner-protocol/v3.bare b/engine/sdks/schemas/runner-protocol/v3.bare index 53c258cece..207b01c828 100644 --- a/engine/sdks/schemas/runner-protocol/v3.bare +++ b/engine/sdks/schemas/runner-protocol/v3.bare @@ -213,7 +213,7 @@ type MessageIdParts struct { messageIndex: MessageIndex } -type MessageId data[12] +type MessageId data[10] # Ack (deprecated, older protocols that have gc cycles to check for tunnel ack) @@ -399,14 +399,16 @@ type ToClient union { } # MARK: To Runner -type ToRunnerKeepAlive struct { +type ToRunnerPing struct { + gatewayId: GatewayId requestId: RequestId + ts: i64 } # We have to re-declare the entire union since BARE will not generate the # ser/de for ToClient if it's not a top-level type type ToRunner union { - ToRunnerKeepAlive | + ToRunnerPing | ToClientInit | ToClientClose | ToClientCommands | @@ -416,10 +418,13 @@ type ToRunner union { } # MARK: To Gateway -type ToGatewayKeepAlive void +type ToGatewayPong struct { + requestId: RequestId + ts: i64 +} type ToGateway union { - ToGatewayKeepAlive | + ToGatewayPong | ToServerTunnelMessage } diff --git a/engine/sdks/typescript/runner-protocol/src/index.ts b/engine/sdks/typescript/runner-protocol/src/index.ts index 3e1d938483..ffa2222ce5 100644 --- a/engine/sdks/typescript/runner-protocol/src/index.ts +++ 
b/engine/sdks/typescript/runner-protocol/src/index.ts @@ -1018,11 +1018,11 @@ export function decodeMessageIdParts(bytes: Uint8Array): MessageIdParts { export type MessageId = ArrayBuffer export function readMessageId(bc: bare.ByteCursor): MessageId { - return bare.readFixedData(bc, 12) + return bare.readFixedData(bc, 10) } export function writeMessageId(bc: bare.ByteCursor, x: MessageId): void { - assert(x.byteLength === 12) + assert(x.byteLength === 10) bare.writeFixedData(bc, x) } @@ -1899,18 +1899,24 @@ export function decodeToClient(bytes: Uint8Array): ToClient { /** * MARK: To Runner */ -export type ToRunnerKeepAlive = { +export type ToRunnerPing = { + readonly gatewayId: GatewayId readonly requestId: RequestId + readonly ts: i64 } -export function readToRunnerKeepAlive(bc: bare.ByteCursor): ToRunnerKeepAlive { +export function readToRunnerPing(bc: bare.ByteCursor): ToRunnerPing { return { + gatewayId: readGatewayId(bc), requestId: readRequestId(bc), + ts: bare.readI64(bc), } } -export function writeToRunnerKeepAlive(bc: bare.ByteCursor, x: ToRunnerKeepAlive): void { +export function writeToRunnerPing(bc: bare.ByteCursor, x: ToRunnerPing): void { + writeGatewayId(bc, x.gatewayId) writeRequestId(bc, x.requestId) + bare.writeI64(bc, x.ts) } /** @@ -1918,7 +1924,7 @@ export function writeToRunnerKeepAlive(bc: bare.ByteCursor, x: ToRunnerKeepAlive * ser/de for ToClient if it's not a top-level type */ export type ToRunner = - | { readonly tag: "ToRunnerKeepAlive"; readonly val: ToRunnerKeepAlive } + | { readonly tag: "ToRunnerPing"; readonly val: ToRunnerPing } | { readonly tag: "ToClientInit"; readonly val: ToClientInit } | { readonly tag: "ToClientClose"; readonly val: ToClientClose } | { readonly tag: "ToClientCommands"; readonly val: ToClientCommands } @@ -1931,7 +1937,7 @@ export function readToRunner(bc: bare.ByteCursor): ToRunner { const tag = bare.readU8(bc) switch (tag) { case 0: - return { tag: "ToRunnerKeepAlive", val: readToRunnerKeepAlive(bc) } + 
return { tag: "ToRunnerPing", val: readToRunnerPing(bc) } case 1: return { tag: "ToClientInit", val: readToClientInit(bc) } case 2: @@ -1953,9 +1959,9 @@ export function readToRunner(bc: bare.ByteCursor): ToRunner { export function writeToRunner(bc: bare.ByteCursor, x: ToRunner): void { switch (x.tag) { - case "ToRunnerKeepAlive": { + case "ToRunnerPing": { bare.writeU8(bc, 0) - writeToRunnerKeepAlive(bc, x.val) + writeToRunnerPing(bc, x.val) break } case "ToClientInit": { @@ -2012,10 +2018,25 @@ export function decodeToRunner(bytes: Uint8Array): ToRunner { /** * MARK: To Gateway */ -export type ToGatewayKeepAlive = null +export type ToGatewayPong = { + readonly requestId: RequestId + readonly ts: i64 +} + +export function readToGatewayPong(bc: bare.ByteCursor): ToGatewayPong { + return { + requestId: readRequestId(bc), + ts: bare.readI64(bc), + } +} + +export function writeToGatewayPong(bc: bare.ByteCursor, x: ToGatewayPong): void { + writeRequestId(bc, x.requestId) + bare.writeI64(bc, x.ts) +} export type ToGateway = - | { readonly tag: "ToGatewayKeepAlive"; readonly val: ToGatewayKeepAlive } + | { readonly tag: "ToGatewayPong"; readonly val: ToGatewayPong } | { readonly tag: "ToServerTunnelMessage"; readonly val: ToServerTunnelMessage } export function readToGateway(bc: bare.ByteCursor): ToGateway { @@ -2023,7 +2044,7 @@ export function readToGateway(bc: bare.ByteCursor): ToGateway { const tag = bare.readU8(bc) switch (tag) { case 0: - return { tag: "ToGatewayKeepAlive", val: null } + return { tag: "ToGatewayPong", val: readToGatewayPong(bc) } case 1: return { tag: "ToServerTunnelMessage", val: readToServerTunnelMessage(bc) } default: { @@ -2035,8 +2056,9 @@ export function readToGateway(bc: bare.ByteCursor): ToGateway { export function writeToGateway(bc: bare.ByteCursor, x: ToGateway): void { switch (x.tag) { - case "ToGatewayKeepAlive": { + case "ToGatewayPong": { bare.writeU8(bc, 0) + writeToGatewayPong(bc, x.val) break } case "ToServerTunnelMessage": { diff 
--git a/engine/sdks/typescript/runner/src/mod.ts b/engine/sdks/typescript/runner/src/mod.ts index 86a223e76b..815ab56207 100644 --- a/engine/sdks/typescript/runner/src/mod.ts +++ b/engine/sdks/typescript/runner/src/mod.ts @@ -2,8 +2,7 @@ import * as protocol from "@rivetkit/engine-runner-protocol"; import type { Logger } from "pino"; import type WebSocket from "ws"; import { logger, setLogger } from "./log.js"; -import { stringifyCommandWrapper, stringifyEvent } from "./stringify"; -import type { PendingRequest } from "./tunnel"; +import { stringifyToClient, stringifyToServer } from "./stringify"; import { type HibernatingWebSocketMetadata, Tunnel } from "./tunnel"; import { calculateBackoff, @@ -11,7 +10,6 @@ import { unreachable, } from "./utils"; import { importWebSocket } from "./websocket.js"; -import type { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter"; import { RunnerActor, type ActorConfig } from "./actor"; export type { HibernatingWebSocketMetadata }; @@ -501,20 +499,10 @@ export class Runner { // NOTE: We don't use #sendToServer here because that function checks if the runner is // shut down - const encoded = protocol.encodeToServer({ + this.__sendToServer({ tag: "ToServerStopping", val: null, }); - if ( - this.#pegboardWebSocket && - this.#pegboardWebSocket.readyState === 1 - ) { - this.#pegboardWebSocket.send(encoded); - } else { - this.log?.error( - "WebSocket not available or not open for sending data", - ); - } const closePromise = new Promise((resolve) => { if (!pegboardWebSocket) @@ -708,6 +696,10 @@ export class Runner { // Parse message const message = protocol.decodeToClient(buf); + this.log?.debug({ + msg: "received runner message", + data: stringifyToClient(message), + }); // Handle message if (message.tag === "ToClientInit") { @@ -849,10 +841,6 @@ export class Runner { }); for (const commandWrapper of commands) { - this.log?.info({ - msg: "received command", - command: stringifyCommandWrapper(commandWrapper), - }); if 
(commandWrapper.inner.tag === "CommandStartActor") { // Spawn background promise this.#handleCommandStartActor(commandWrapper); @@ -1008,12 +996,6 @@ export class Runner { this.#recordEvent(eventWrapper); - this.log?.info({ - msg: "sending event to server", - event: stringifyEvent(eventWrapper.inner), - index: eventWrapper.index.toString(), - }); - this.__sendToServer({ tag: "ToServerEvents", val: [eventWrapper], @@ -1064,12 +1046,6 @@ export class Runner { this.#recordEvent(eventWrapper); - this.log?.info({ - msg: "sending event to server", - event: stringifyEvent(eventWrapper.inner), - index: eventWrapper.index.toString(), - }); - this.__sendToServer({ tag: "ToServerEvents", val: [eventWrapper], @@ -1510,14 +1486,19 @@ export class Runner { : false; } - __sendToServer(message: protocol.ToServer) { - if (this.#shutdown) { + __sendToServer(message: protocol.ToServer, allowShutdown: boolean = false) { + if (!allowShutdown && this.#shutdown) { this.log?.warn({ msg: "Runner is shut down, cannot send message to server", }); return; } + this.log?.debug({ + msg: "sending runner message", + data: stringifyToServer(message), + }); + const encoded = protocol.encodeToServer(message); if ( this.#pegboardWebSocket && diff --git a/engine/sdks/typescript/runner/src/stringify.ts b/engine/sdks/typescript/runner/src/stringify.ts index 25f345a789..4f150fecc0 100644 --- a/engine/sdks/typescript/runner/src/stringify.ts +++ b/engine/sdks/typescript/runner/src/stringify.ts @@ -182,3 +182,154 @@ export function stringifyEvent(event: protocol.Event): string { export function stringifyEventWrapper(wrapper: protocol.EventWrapper): string { return `EventWrapper{index: ${stringifyBigInt(wrapper.index)}, inner: ${stringifyEvent(wrapper.inner)}}`; } + +/** + * Stringify ToServer for logging + * Handles ArrayBuffers, BigInts, and Maps that can't be JSON.stringified + */ +export function stringifyToServer(message: protocol.ToServer): string { + switch (message.tag) { + case "ToServerInit": { + 
const { name, version, totalSlots, lastCommandIdx, metadata } = + message.val; + const lastCommandIdxStr = + lastCommandIdx === null + ? "null" + : stringifyBigInt(lastCommandIdx); + const metadataStr = metadata === null ? "null" : `"${metadata}"`; + return `ToServerInit{name: "${name}", version: ${version}, totalSlots: ${totalSlots}, lastCommandIdx: ${lastCommandIdxStr}, metadata: ${metadataStr}}`; + } + case "ToServerEvents": { + const events = message.val; + return `ToServerEvents{count: ${events.length}, events: [${events.map((e) => stringifyEventWrapper(e)).join(", ")}]}`; + } + case "ToServerAckCommands": { + const { lastCommandIdx } = message.val; + return `ToServerAckCommands{lastCommandIdx: ${stringifyBigInt(lastCommandIdx)}}`; + } + case "ToServerStopping": + return "ToServerStopping"; + case "ToServerPing": { + const { ts } = message.val; + return `ToServerPing{ts: ${stringifyBigInt(ts)}}`; + } + case "ToServerKvRequest": { + const { actorId, requestId, data } = message.val; + const dataStr = stringifyKvRequestData(data); + return `ToServerKvRequest{actorId: "${actorId}", requestId: ${requestId}, data: ${dataStr}}`; + } + case "ToServerTunnelMessage": { + const { messageId, messageKind } = message.val; + const messageIdStr = stringifyArrayBuffer(messageId); + return `ToServerTunnelMessage{messageId: ${messageIdStr}, messageKind: ${stringifyToServerTunnelMessageKind(messageKind)}}`; + } + } +} + +/** + * Stringify ToClient for logging + * Handles ArrayBuffers, BigInts, and Maps that can't be JSON.stringified + */ +export function stringifyToClient(message: protocol.ToClient): string { + switch (message.tag) { + case "ToClientInit": { + const { runnerId, lastEventIdx, metadata } = message.val; + const runnerLostThreshold = metadata?.runnerLostThreshold + ? 
stringifyBigInt(metadata.runnerLostThreshold) + : "null"; + return `ToClientInit{runnerId: "${runnerId}", lastEventIdx: ${stringifyBigInt(lastEventIdx)}, runnerLostThreshold: ${runnerLostThreshold}}`; + } + case "ToClientClose": + return "ToClientClose"; + case "ToClientCommands": { + const commands = message.val; + return `ToClientCommands{count: ${commands.length}, commands: [${commands.map((c) => stringifyCommandWrapper(c)).join(", ")}]}`; + } + case "ToClientAckEvents": { + const { lastEventIdx } = message.val; + return `ToClientAckEvents{lastEventIdx: ${stringifyBigInt(lastEventIdx)}}`; + } + case "ToClientKvResponse": { + const { requestId, data } = message.val; + const dataStr = stringifyKvResponseData(data); + return `ToClientKvResponse{requestId: ${requestId}, data: ${dataStr}}`; + } + case "ToClientTunnelMessage": { + const { messageId, messageKind } = message.val; + const messageIdStr = stringifyArrayBuffer(messageId); + return `ToClientTunnelMessage{messageId: ${messageIdStr}, messageKind: ${stringifyToClientTunnelMessageKind(messageKind)}}`; + } + } +} + +/** + * Stringify KvRequestData for logging + */ +function stringifyKvRequestData(data: protocol.KvRequestData): string { + switch (data.tag) { + case "KvGetRequest": { + const { keys } = data.val; + return `KvGetRequest{keys: ${keys.length}}`; + } + case "KvListRequest": { + const { query, reverse, limit } = data.val; + const reverseStr = reverse === null ? "null" : reverse.toString(); + const limitStr = limit === null ? 
"null" : stringifyBigInt(limit); + return `KvListRequest{query: ${stringifyKvListQuery(query)}, reverse: ${reverseStr}, limit: ${limitStr}}`; + } + case "KvPutRequest": { + const { keys, values } = data.val; + return `KvPutRequest{keys: ${keys.length}, values: ${values.length}}`; + } + case "KvDeleteRequest": { + const { keys } = data.val; + return `KvDeleteRequest{keys: ${keys.length}}`; + } + case "KvDropRequest": + return "KvDropRequest"; + } +} + +/** + * Stringify KvListQuery for logging + */ +function stringifyKvListQuery(query: protocol.KvListQuery): string { + switch (query.tag) { + case "KvListAllQuery": + return "KvListAllQuery"; + case "KvListRangeQuery": { + const { start, end, exclusive } = query.val; + return `KvListRangeQuery{start: ${stringifyArrayBuffer(start)}, end: ${stringifyArrayBuffer(end)}, exclusive: ${exclusive}}`; + } + case "KvListPrefixQuery": { + const { key } = query.val; + return `KvListPrefixQuery{key: ${stringifyArrayBuffer(key)}}`; + } + } +} + +/** + * Stringify KvResponseData for logging + */ +function stringifyKvResponseData(data: protocol.KvResponseData): string { + switch (data.tag) { + case "KvErrorResponse": { + const { message } = data.val; + return `KvErrorResponse{message: "${message}"}`; + } + case "KvGetResponse": { + const { keys, values, metadata } = data.val; + return `KvGetResponse{keys: ${keys.length}, values: ${values.length}, metadata: ${metadata.length}}`; + } + case "KvListResponse": { + const { keys, values, metadata } = data.val; + return `KvListResponse{keys: ${keys.length}, values: ${values.length}, metadata: ${metadata.length}}`; + } + case "KvPutResponse": + return "KvPutResponse"; + case "KvDeleteResponse": + return "KvDeleteResponse"; + case "KvDropResponse": + return "KvDropResponse"; + } +} diff --git a/engine/sdks/typescript/runner/src/tunnel-id.ts b/engine/sdks/typescript/runner/src/tunnel-id.ts index bea04a2453..8474e4e4cb 100644 --- a/engine/sdks/typescript/runner/src/tunnel-id.ts +++ 
b/engine/sdks/typescript/runner/src/tunnel-id.ts @@ -38,14 +38,14 @@ export function buildMessageId( const encoded = protocol.encodeMessageIdParts(parts); - if (encoded.byteLength !== 12) { + if (encoded.byteLength !== 10) { throw new Error( - `message id serialization produced wrong size: expected 12 bytes, got ${encoded.byteLength}`, + `message id serialization produced wrong size: expected 10 bytes, got ${encoded.byteLength}`, ); } // Create a new ArrayBuffer from the Uint8Array - const messageId = new ArrayBuffer(12); + const messageId = new ArrayBuffer(10); new Uint8Array(messageId).set(encoded); return messageId; } @@ -54,9 +54,9 @@ export function buildMessageId( * Parse a MessageId into its components */ export function parseMessageId(messageId: MessageId): protocol.MessageIdParts { - if (messageId.byteLength !== 12) { + if (messageId.byteLength !== 10) { throw new Error( - `invalid message id length: expected 12 bytes, got ${messageId.byteLength}`, + `invalid message id length: expected 10 bytes, got ${messageId.byteLength}`, ); } const uint8Array = new Uint8Array(messageId); diff --git a/engine/sdks/typescript/runner/src/tunnel.ts b/engine/sdks/typescript/runner/src/tunnel.ts index 16d836003d..5dbb49741a 100644 --- a/engine/sdks/typescript/runner/src/tunnel.ts +++ b/engine/sdks/typescript/runner/src/tunnel.ts @@ -804,7 +804,7 @@ export class Tunnel { requestId: ArrayBuffer, response: Response, ) { - if (this.#runner.hasActor(actorId, generation)) { + if (!this.#runner.hasActor(actorId, generation)) { this.log?.warn({ msg: "actor not loaded to send response, assuming gateway has closed request", actorId, @@ -854,7 +854,7 @@ export class Tunnel { status: number, message: string, ) { - if (this.#runner.hasActor(actorId, generation)) { + if (!this.#runner.hasActor(actorId, generation)) { this.log?.warn({ msg: "actor not loaded to send response, assuming gateway has closed request", actorId, diff --git 
a/rivetkit-typescript/packages/rivetkit/src/common/utils.ts b/rivetkit-typescript/packages/rivetkit/src/common/utils.ts index e06b209391..0f505386e1 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/utils.ts @@ -245,6 +245,7 @@ export function deconstructError( group, code, message, + stack: (error as Error)?.stack, ...EXTRA_ERROR_LOG, ...extraLog, }); @@ -260,6 +261,7 @@ export function deconstructError( group, code, message, + stack: (error as Error)?.stack, ...EXTRA_ERROR_LOG, ...extraLog, }); diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/utils.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/utils.ts index 0ebbbdf2d7..028bce6d29 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/utils.ts @@ -1,4 +1,3 @@ -import invariant from "invariant"; import { type TestContext, vi } from "vitest"; import { assertUnreachable } from "@/actor/utils"; import { type Client, createClient } from "@/client/mod"; @@ -7,6 +6,7 @@ import { createClientWithDriver } from "@/mod"; import type { registry } from "../../fixtures/driver-test-suite/registry"; import type { DriverTestConfig } from "./mod"; import { createTestInlineClientDriver } from "./test-inline-client-driver"; +import { logger } from "./log"; export const FAKE_TIME = new Date("2024-01-01T00:00:00.000Z"); @@ -26,7 +26,10 @@ export async function setupDriverTest( // Build drivers const { endpoint, namespace, runnerName, cleanup } = await driverTestConfig.start(); - c.onTestFinished(cleanup); + c.onTestFinished(() => { + logger().info("cleaning up test"); + cleanup(); + }); let client: Client; if (driverTestConfig.clientType === "http") { diff --git a/scripts/run/docker/engine-rocksdb.sh b/scripts/run/docker/engine-rocksdb.sh index 42690862ba..fa553cafbd 100755 --- a/scripts/run/docker/engine-rocksdb.sh +++ 
b/scripts/run/docker/engine-rocksdb.sh @@ -7,6 +7,7 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" cd "${REPO_ROOT}" RUST_LOG="${RUST_LOG:-"opentelemetry_sdk=off,opentelemetry-otlp=info,tower::buffer::worker=info,debug"}" \ +RUST_LOG_TARGET=1 \ RIVET__PEGBOARD__RETRY_RESET_DURATION="100" \ RIVET__PEGBOARD__BASE_RETRY_TIMEOUT="100" \ RIVET__PEGBOARD__RESCHEDULE_BACKOFF_MAX_EXPONENT="1" \