From 3792d39af6875ffd325793b533dd4b614e555cc9 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Mon, 8 Jul 2024 14:54:53 -0400 Subject: [PATCH] Use a drop guard to log aborted futures --- ui/src/request_database.rs | 36 +++++++++++++++++++-- ui/src/server_axum.rs | 8 ++--- ui/src/server_axum/websocket.rs | 57 ++++++++++++++------------------- 3 files changed, 61 insertions(+), 40 deletions(-) diff --git a/ui/src/request_database.rs b/ui/src/request_database.rs index a45686fe..918873bd 100644 --- a/ui/src/request_database.rs +++ b/ui/src/request_database.rs @@ -174,7 +174,7 @@ impl Handle { rx.await.context(RecvStartRequestSnafu)?.map_err(Into::into) } - pub async fn attempt_start_request( + async fn attempt_start_request( &self, category: impl Into, payload: impl Into, @@ -200,11 +200,43 @@ impl Handle { rx.await.context(RecvEndRequestSnafu)?.map_err(Into::into) } - pub async fn attempt_end_request(&self, id: Id, how: How) { + async fn attempt_end_request(&self, id: Id, how: How) { if let Err(err) = self.end_request(id, how).await { warn!(?err, "Unable to record end request"); } } + + pub async fn start_with_guard( + self, + category: impl Into, + payload: impl Into, + ) -> EndGuard { + let g = self + .attempt_start_request(category, payload) + .await + .map(|id| EndGuardInner(id, How::Abandoned, self)); + EndGuard(g) + } +} + +pub struct EndGuard(Option); + +impl EndGuard { + pub fn complete_now(mut self) { + if let Some(mut inner) = self.0.take() { + inner.1 = How::Complete; + drop(inner); + } + } +} + +struct EndGuardInner(Id, How, Handle); + +impl Drop for EndGuardInner { + fn drop(&mut self) { + let Self(id, how, ref handle) = *self; + futures::executor::block_on(handle.attempt_end_request(id, how)) + } } #[derive(Debug, Snafu)] diff --git a/ui/src/server_axum.rs b/ui/src/server_axum.rs index f4b8d31e..f29362d1 100644 --- a/ui/src/server_axum.rs +++ b/ui/src/server_axum.rs @@ -4,7 +4,7 @@ use crate::{ record_metric, track_metric_no_request_async, Endpoint, HasLabelsCore, Outcome, UNAVAILABLE_WS, }, - request_database::{Handle, How}, + request_database::Handle, Config, GhToken, MetricsToken, }; use async_trait::async_trait; @@ -198,13 +198,11 @@ where { let category = format!("http.{}", <&str>::from(T::ENDPOINT)); let payload = serde_json::to_string(&req).unwrap_or_else(|_| String::from("")); - let id = db.attempt_start_request(category, payload).await; + let guard = db.start_with_guard(category, payload).await; let r = f(req).await; - if let Some(id) = id { - db.attempt_end_request(id, How::Complete).await; - } + guard.complete_now(); r } diff --git a/ui/src/server_axum/websocket.rs b/ui/src/server_axum/websocket.rs index d86b6232..b8b6409d 100644 --- a/ui/src/server_axum/websocket.rs +++ b/ui/src/server_axum/websocket.rs @@ -1,6 +1,6 @@ use crate::{ metrics::{self, record_metric, Endpoint, HasLabelsCore, Outcome}, - request_database::{Handle, How}, + request_database::Handle, server_axum::api_orchestrator_integration_impls::*, }; @@ -389,12 +389,6 @@ async fn handle_core( resp = rx.recv() => { let resp = resp.expect("The rx should never close as we have a tx"); - if let Ok(MessageResponse::ExecuteEnd { meta, .. }) = &resp { - if let Some((_, _, Some(db_id))) = active_executions.get(&meta.sequence_number) { - db.attempt_end_request(*db_id, How::Complete).await; - } - } - let success = resp.is_ok(); let resp = resp.unwrap_or_else(error_to_response); let resp = response_to_message(resp); @@ -443,7 +437,7 @@ async fn handle_core( _ = active_execution_gc_interval.tick() => { active_executions = mem::take(&mut active_executions) .into_iter() - .filter(|(_id, (_, tx, _))| tx.as_ref().map_or(false, |tx| !tx.is_closed())) + .filter(|(_id, (_, tx))| tx.as_ref().map_or(false, |tx| !tx.is_closed())) .collect(); }, @@ -464,12 +458,6 @@ async fn handle_core( } } - for (_, (_, _, db_id)) in active_executions { - if let Some(db_id) = db_id { - db.attempt_end_request(db_id, How::Abandoned).await; - } - } - drop((tx, rx, socket)); if let Err(e) = manager.shutdown().await { error!("Could not shut down the Coordinator: {e:?}"); @@ -516,11 +504,7 @@ fn response_to_message(response: MessageResponse) -> Message { Message::Text(resp) } -type ActiveExecutionInfo = ( - CancellationToken, - Option>, - Option, -); +type ActiveExecutionInfo = (CancellationToken, Option>); async fn handle_msg( txt: String, @@ -538,22 +522,31 @@ async fn handle_msg( let token = CancellationToken::new(); let (execution_tx, execution_rx) = mpsc::channel(8); - let id = db.attempt_start_request("ws.Execute", &txt).await; + let guard = db.clone().start_with_guard("ws.Execute", &txt).await; - active_executions.insert( - meta.sequence_number, - (token.clone(), Some(execution_tx), id), - ); + active_executions.insert(meta.sequence_number, (token.clone(), Some(execution_tx))); // TODO: Should a single execute / build / etc. session have a timeout of some kind? let spawned = manager .spawn({ let tx = tx.clone(); let meta = meta.clone(); - |coordinator| { - handle_execute(token, execution_rx, tx, coordinator, payload, meta.clone()) - .context(StreamingExecuteSnafu) - .map_err(|e| (e, Some(meta))) + |coordinator| async { + let r = handle_execute( + token, + execution_rx, + tx, + coordinator, + payload, + meta.clone(), + ) + .context(StreamingExecuteSnafu) + .map_err(|e| (e, Some(meta))) + .await; + + guard.complete_now(); + + r } }) .await @@ -565,8 +558,7 @@ async fn handle_msg( } Ok(ExecuteStdin { payload, meta }) => { - let Some((_, Some(execution_tx), _)) = active_executions.get(&meta.sequence_number) - else { + let Some((_, Some(execution_tx))) = active_executions.get(&meta.sequence_number) else { warn!("Received stdin for an execution that is no longer active"); return; }; @@ -582,8 +574,7 @@ async fn handle_msg( } Ok(ExecuteStdinClose { meta }) => { - let Some((_, execution_tx, _)) = active_executions.get_mut(&meta.sequence_number) - else { + let Some((_, execution_tx)) = active_executions.get_mut(&meta.sequence_number) else { warn!("Received stdin close for an execution that is no longer active"); return; }; @@ -592,7 +583,7 @@ async fn handle_msg( } Ok(ExecuteKill { meta }) => { - let Some((token, _, _)) = active_executions.get(&meta.sequence_number) else { + let Some((token, _)) = active_executions.get(&meta.sequence_number) else { warn!("Received kill for an execution that is no longer active"); return; };