Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions engine/artifacts/errors/ws.going_away.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions engine/packages/gasoline/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(10);
const METRICS_INTERVAL: Duration = Duration::from_secs(20);
// How long the pull workflows function can take before shutting down the runtime.
const PULL_WORKFLOWS_TIMEOUT: Duration = Duration::from_secs(10);
const SHUTDOWN_PROGRESS_INTERVAL: Duration = Duration::from_secs(7);

/// Used to spawn a new thread that indefinitely polls the database for new workflows. Only pulls workflows
/// that are registered in its registry. After pulling, the workflows are ran and their state is written to
Expand Down Expand Up @@ -297,6 +298,9 @@ impl Worker {
.map(|(_, wf)| &mut wf.handle)
.collect::<FuturesUnordered<_>>();

let mut progress_interval = tokio::time::interval(SHUTDOWN_PROGRESS_INTERVAL);
progress_interval.tick().await;

let shutdown_start = Instant::now();
loop {
// Future will resolve once all workflow tasks complete
Expand All @@ -306,6 +310,9 @@ impl Worker {
_ = join_fut => {
break;
}
_ = progress_interval.tick() => {
tracing::info!(remaining_workflows=%wf_futs.len(), "worker still shutting down");
}
abort = term_signal.recv() => {
if abort {
tracing::warn!("aborting worker shutdown");
Expand Down
4 changes: 4 additions & 0 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2522,6 +2522,10 @@ impl ProxyServiceFactory {
pub async fn wait_idle(&self) {
self.state.tasks.wait_idle().await
}

pub fn remaining_tasks(&self) -> usize {
self.state.tasks.remaining_tasks()
}
}

fn add_proxy_headers_with_addr(
Expand Down
41 changes: 34 additions & 7 deletions engine/packages/guard-core/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::{
net::SocketAddr,
sync::Arc,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::{Duration, Instant},
};

Expand All @@ -15,6 +18,8 @@ use crate::cert_resolver::{CertResolverFn, create_tls_config};
use crate::metrics;
use crate::proxy_service::{CacheKeyFn, MiddlewareFn, ProxyServiceFactory, RoutingFn};

const SHUTDOWN_PROGRESS_INTERVAL: Duration = Duration::from_secs(7);

// Start the server
#[tracing::instrument(skip_all)]
pub async fn run_server(
Expand Down Expand Up @@ -248,26 +253,41 @@ pub async fn run_server(
}

let shutdown_duration = config.runtime.guard_shutdown_duration();
tracing::info!(duration=?shutdown_duration, "starting guard shutdown");

let remaining_tasks = http_factory.remaining_tasks()
+ https_factory
.as_ref()
.map(|f| f.remaining_tasks())
.unwrap_or(0);
tracing::info!(%remaining_tasks, hyper_shutdown=%false, duration=?shutdown_duration, "starting guard shutdown");

// Signifies that the hyper graceful shutdown completed
let hyper_shutdown = Arc::new(AtomicBool::new(false));

let hyper_shutdown2 = hyper_shutdown.clone();
let http_factory2 = http_factory.clone();
let https_factory2 = https_factory.clone();
let mut complete_fut = async move {
// Wait until remaining requests finish
graceful.shutdown().await;
hyper_shutdown2.store(true, Ordering::Release);

// Wait until remaining tasks finish
http_factory.wait_idle().await;
http_factory2.wait_idle().await;

if let Some(https_factory) = https_factory {
if let Some(https_factory) = https_factory2 {
https_factory.wait_idle().await;
}
}
.boxed();

let mut progress_interval = tokio::time::interval(SHUTDOWN_PROGRESS_INTERVAL);
progress_interval.tick().await;

let shutdown_start = Instant::now();
loop {
tokio::select! {
_ = &mut complete_fut => {
tracing::info!("all guard requests completed");
tracing::info!("all guard tasks completed");
break;
}
abort = term_signal.recv() => {
Expand All @@ -276,8 +296,15 @@ pub async fn run_server(
break;
}
}
_ = progress_interval.tick() => {
let remaining_tasks = http_factory.remaining_tasks() +
https_factory.as_ref().map(|f| f.remaining_tasks()).unwrap_or(0);
let hyper_shutdown = hyper_shutdown.load(Ordering::Acquire);

tracing::info!(%remaining_tasks, hyper_shutdown, "guard still shutting down");
}
_ = tokio::time::sleep(shutdown_duration.saturating_sub(shutdown_start.elapsed())) => {
tracing::warn!("guard shutdown timed out before all requests completed");
tracing::warn!("guard shutdown timed out before all tasks completed");
break;
}
}
Expand Down
4 changes: 4 additions & 0 deletions engine/packages/guard-core/src/task_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,8 @@ impl TaskGroup {
}
}
}

pub fn remaining_tasks(&self) -> usize {
self.running_count.load(Ordering::Acquire)
}
}
5 changes: 5 additions & 0 deletions engine/packages/pegboard-runner/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ pub enum WsError {
"The websocket has been evicted and should not attempt to reconnect."
)]
Eviction,
#[error(
"going_away",
"The Rivet Engine is migrating. The websocket should attempt to reconnect as soon as possible."
)]
GoingAway,
#[error(
"timed_out_waiting_for_init",
"Timed out waiting for the init packet to be sent."
Expand Down
14 changes: 13 additions & 1 deletion engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,17 @@ pub async fn task_inner(
event_demuxer: &mut ActorEventDemuxer,
) -> Result<LifecycleResult> {
let mut ws_rx = ws_rx.lock().await;
let mut term_signal = rivet_runtime::TermSignal::new().await;

loop {
match recv_msg(&mut ws_rx, &mut eviction_sub2, &mut ws_to_tunnel_abort_rx).await? {
match recv_msg(
&mut ws_rx,
&mut eviction_sub2,
&mut ws_to_tunnel_abort_rx,
&mut term_signal,
)
.await?
{
Ok(Some(msg)) => {
if protocol::is_mk2(conn.protocol_version) {
handle_message_mk2(&ctx, &conn, event_demuxer, msg).await?;
Expand All @@ -74,6 +82,7 @@ async fn recv_msg(
ws_rx: &mut MutexGuard<'_, WebSocketReceiver>,
eviction_sub2: &mut Subscriber,
ws_to_tunnel_abort_rx: &mut watch::Receiver<()>,
term_signal: &mut rivet_runtime::TermSignal,
) -> Result<std::result::Result<Option<Bytes>, LifecycleResult>> {
let msg = tokio::select! {
res = ws_rx.try_next() => {
Expand All @@ -92,6 +101,9 @@ async fn recv_msg(
tracing::debug!("task aborted");
return Ok(Err(LifecycleResult::Aborted));
}
_ = term_signal.recv() => {
return Err(errors::WsError::GoingAway.build());
}
};

match msg {
Expand Down
Loading