diff --git a/src/render.rs b/src/render.rs index aecd118b..c9229ce7 100644 --- a/src/render.rs +++ b/src/render.rs @@ -80,7 +80,7 @@ pub enum PageContext { /// Handles the first SQL statements, before the headers have been sent to pub struct HeaderContext { app_state: Arc, - request_context: RequestContext, + pub request_context: RequestContext, pub writer: ResponseWriter, response: HttpResponseBuilder, has_status: bool, @@ -368,7 +368,14 @@ impl HeaderContext { Ok(PageContext::Header(self)) } - async fn start_body(self, data: JsonValue) -> anyhow::Result { + fn add_server_timing_header(&mut self) { + if let Some(header_value) = self.request_context.server_timing.header_value() { + self.response.insert_header(("Server-Timing", header_value)); + } + } + + async fn start_body(mut self, data: JsonValue) -> anyhow::Result { + self.add_server_timing_header(); let html_renderer = HtmlRenderContext::new(self.app_state, self.request_context, self.writer, data) .await @@ -382,6 +389,7 @@ impl HeaderContext { } pub fn close(mut self) -> HttpResponse { + self.add_server_timing_header(); self.response.finish() } } diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index 80883e24..9f780f72 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -52,13 +52,13 @@ pub fn stream_query_results_with_conn<'a>( for res in &sql_file.statements { match res { ParsedStatement::CsvImport(csv_import) => { - let connection = take_connection(&request.app_state.db, db_connection).await?; + let connection = take_connection(&request.app_state.db, db_connection, request).await?; log::debug!("Executing CSV import: {csv_import:?}"); run_csv_import(connection, csv_import, request).await.with_context(|| format!("Failed to import the CSV file {:?} into the table {:?}", csv_import.uploaded_file, csv_import.table_name))?; }, ParsedStatement::StmtWithParams(stmt) => { let query = bind_parameters(stmt, request, db_connection).await?; - let connection = take_connection(&request.app_state.db, db_connection).await?; + let connection = take_connection(&request.app_state.db, db_connection, request).await?; log::trace!("Executing query {:?}", query.sql); let mut stream = connection.fetch_many(query); let mut error = None; @@ -192,7 +192,7 @@ async fn execute_set_variable_query<'a>( source_file: &Path, ) -> anyhow::Result<()> { let query = bind_parameters(statement, request, db_connection).await?; - let connection = take_connection(&request.app_state.db, db_connection).await?; + let connection = take_connection(&request.app_state.db, db_connection, request).await?; log::debug!( "Executing query to set the {variable:?} variable: {:?}", query.sql @@ -276,6 +276,7 @@ fn vars_and_name<'a, 'b>( async fn take_connection<'a>( db: &'a Database, conn: &'a mut DbConn, + request: &RequestInfo, ) -> anyhow::Result<&'a mut PoolConnection> { if let Some(c) = conn { return Ok(c); @@ -283,6 +284,7 @@ async fn take_connection<'a>( match db.connection.acquire().await { Ok(c) => { log::debug!("Acquired a database connection"); + request.server_timing.record("db_conn"); *conn = Some(c); Ok(conn.as_mut().unwrap()) } diff --git a/src/webserver/http.rs b/src/webserver/http.rs index 88c9966e..dbf3e208 100644 --- a/src/webserver/http.rs +++ b/src/webserver/http.rs @@ -7,6 +7,7 @@ use crate::webserver::content_security_policy::ContentSecurityPolicy; use crate::webserver::database::execute_queries::stop_at_first_error; use crate::webserver::database::{execute_queries::stream_query_results_with_conn, DbItem}; use crate::webserver::http_request_info::extract_request_info; +use crate::webserver::server_timing::ServerTiming; use crate::webserver::ErrorWithStatus; use crate::{AppConfig, AppState, ParsedSqlFile, DEFAULT_404_FILE}; use actix_web::dev::{fn_service, ServiceFactory, ServiceRequest}; @@ -46,6 +47,7 @@ pub struct RequestContext { pub is_embedded: bool, pub source_path: PathBuf, pub content_security_policy: ContentSecurityPolicy, + pub server_timing: Arc, } async fn stream_response(stream: impl Stream, mut renderer: AnyRenderBodyContext) { @@ -106,7 +108,10 @@ async fn build_response_header_and_stream>( let mut stream = Box::pin(database_entries); while let Some(item) = stream.next().await { let page_context = match item { - DbItem::Row(data) => head_context.handle_row(data).await?, + DbItem::Row(data) => { + head_context.request_context.server_timing.record("row"); + head_context.handle_row(data).await? + } DbItem::FinishedQuery => { log::debug!("finished query"); continue; @@ -163,18 +168,21 @@ enum ResponseWithWriter { async fn render_sql( srv_req: &mut ServiceRequest, sql_file: Arc, + server_timing: ServerTiming, ) -> actix_web::Result { let app_state = srv_req .app_data::>() .ok_or_else(|| ErrorInternalServerError("no state"))? - .clone() // Cheap reference count increase + .clone() .into_inner(); - let mut req_param = extract_request_info(srv_req, Arc::clone(&app_state)) + let mut req_param = extract_request_info(srv_req, Arc::clone(&app_state), server_timing) .await .map_err(|e| anyhow_err_to_actix(e, &app_state))?; log::debug!("Received a request with the following parameters: {req_param:?}"); + req_param.server_timing.record("parse_req"); + let (resp_send, resp_recv) = tokio::sync::oneshot::channel::(); let source_path: PathBuf = sql_file.source_path.clone(); actix_web::rt::spawn(async move { @@ -182,6 +190,7 @@ async fn render_sql( is_embedded: req_param.get_variables.contains_key("_sqlpage_embed"), source_path, content_security_policy: ContentSecurityPolicy::with_random_nonce(), + server_timing: Arc::clone(&req_param.server_timing), }; let mut conn = None; let database_entries_stream = @@ -275,13 +284,17 @@ async fn process_sql_request( sql_path: PathBuf, ) -> actix_web::Result { let app_state: &web::Data = req.app_data().expect("app_state"); + let server_timing = ServerTiming::for_env(app_state.config.environment); + let sql_file = app_state .sql_file_cache .get_with_privilege(app_state, &sql_path, false) .await .with_context(|| format!("Unable to read SQL file \"{}\"", sql_path.display())) .map_err(|e| anyhow_err_to_actix(e, app_state))?; - render_sql(req, sql_file).await + server_timing.record("sql_file"); + + render_sql(req, sql_file, server_timing).await } async fn serve_file( diff --git a/src/webserver/http_request_info.rs b/src/webserver/http_request_info.rs index 23675a51..ff0f3114 100644 --- a/src/webserver/http_request_info.rs +++ b/src/webserver/http_request_info.rs @@ -1,3 +1,4 @@ +use crate::webserver::server_timing::ServerTiming; use crate::AppState; use actix_multipart::form::bytes::Bytes; use actix_multipart::form::tempfile::TempFile; @@ -42,6 +43,7 @@ pub struct RequestInfo { pub clone_depth: u8, pub raw_body: Option>, pub oidc_claims: Option, + pub server_timing: Arc, } impl RequestInfo { @@ -62,6 +64,7 @@ impl RequestInfo { clone_depth: self.clone_depth + 1, raw_body: self.raw_body.clone(), oidc_claims: self.oidc_claims.clone(), + server_timing: Arc::clone(&self.server_timing), } } } @@ -78,6 +81,7 @@ impl Clone for RequestInfo { pub(crate) async fn extract_request_info( req: &mut ServiceRequest, app_state: Arc, + server_timing: ServerTiming, ) -> anyhow::Result { let (http_req, payload) = req.parts_mut(); let method = http_req.method().clone(); @@ -123,6 +127,7 @@ pub(crate) async fn extract_request_info( clone_depth: 0, raw_body, oidc_claims, + server_timing: Arc::new(server_timing), }) } @@ -275,7 +280,7 @@ async fn is_file_field_empty( mod test { use super::super::http::SingleOrVec; use super::*; - use crate::app_config::AppConfig; + use crate::{app_config::AppConfig, webserver::server_timing::ServerTiming}; use actix_web::{http::header::ContentType, test::TestRequest}; #[actix_web::test] @@ -284,7 +289,8 @@ mod test { serde_json::from_str::(r#"{"listen_on": "localhost:1234"}"#).unwrap(); let mut service_request = TestRequest::default().to_srv_request(); let app_data = Arc::new(AppState::init(&config).await.unwrap()); - let request_info = extract_request_info(&mut service_request, app_data) + let server_timing = ServerTiming::default(); + let request_info = extract_request_info(&mut service_request, app_data, server_timing) .await .unwrap(); assert_eq!(request_info.post_variables.len(), 0); @@ -302,7 +308,8 @@ mod test { .set_payload("my_array[]=3&my_array[]=Hello%20World&repeated=1&repeated=2") .to_srv_request(); let app_data = Arc::new(AppState::init(&config).await.unwrap()); - let request_info = extract_request_info(&mut service_request, app_data) + let server_timing = ServerTiming::default(); + let request_info = extract_request_info(&mut service_request, app_data, server_timing) .await .unwrap(); assert_eq!( @@ -351,7 +358,8 @@ mod test { ) .to_srv_request(); let app_data = Arc::new(AppState::init(&config).await.unwrap()); - let request_info = extract_request_info(&mut service_request, app_data) + let server_timing = ServerTiming::enabled(false); + let request_info = extract_request_info(&mut service_request, app_data, server_timing) .await .unwrap(); assert_eq!( diff --git a/src/webserver/mod.rs b/src/webserver/mod.rs index 9a74cc04..484ad40d 100644 --- a/src/webserver/mod.rs +++ b/src/webserver/mod.rs @@ -38,6 +38,7 @@ pub mod http_client; pub mod http_request_info; mod https; pub mod request_variables; +pub mod server_timing; pub use database::Database; pub use error_with_status::ErrorWithStatus; diff --git a/src/webserver/server_timing.rs b/src/webserver/server_timing.rs new file mode 100644 index 00000000..44ad4683 --- /dev/null +++ b/src/webserver/server_timing.rs @@ -0,0 +1,70 @@ +use std::fmt::Write; +use std::sync::Mutex; +use std::time::Instant; + +use crate::app_config::DevOrProd; + +#[derive(Debug)] +pub struct ServerTiming { + enabled: bool, + created_at: Instant, + events: Mutex>, +} + +#[derive(Debug)] +struct PerfEvent { + time: Instant, + name: &'static str, +} + +impl Default for ServerTiming { + fn default() -> Self { + Self { + enabled: false, + created_at: Instant::now(), + events: Mutex::new(Vec::new()), + } + } +} + +impl ServerTiming { + #[must_use] + pub fn enabled(enabled: bool) -> Self { + Self { + enabled, + ..Default::default() + } + } + + #[must_use] + pub fn for_env(env: DevOrProd) -> Self { + Self::enabled(!env.is_prod()) + } + + pub fn record(&self, name: &'static str) { + if self.enabled { + self.events.lock().unwrap().push(PerfEvent { + time: Instant::now(), + name, + }); + } + } + + pub fn header_value(&self) -> Option { + if !self.enabled { + return None; + } + let evts = self.events.lock().unwrap(); + let mut s = String::with_capacity(evts.len() * 16); + let mut last = self.created_at; + for (i, PerfEvent { name, time }) in evts.iter().enumerate() { + if i > 0 { + s.push_str(", "); + } + let millis = time.saturating_duration_since(last).as_millis(); + write!(&mut s, "{name};dur={millis}").ok()?; + last = *time; + } + Some(s) + } +} diff --git a/tests/mod.rs b/tests/mod.rs index 9eab13f2..aeaf5dc9 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -5,6 +5,7 @@ mod core; mod data_formats; mod errors; mod requests; +mod server_timing; pub mod sql_test_files; mod transactions; mod uploads; diff --git a/tests/server_timing/mod.rs b/tests/server_timing/mod.rs new file mode 100644 index 00000000..a64eef23 --- /dev/null +++ b/tests/server_timing/mod.rs @@ -0,0 +1,100 @@ +use actix_web::http::StatusCode; +use sqlpage::webserver::http::main_handler; + +use crate::common::{get_request_to, make_app_data_from_config, test_config}; + +#[actix_web::test] +async fn test_server_timing_disabled_in_production() -> actix_web::Result<()> { + let mut config = test_config(); + config.environment = sqlpage::app_config::DevOrProd::Production; + let app_data = make_app_data_from_config(config).await; + + let req = crate::common::get_request_to_with_data( + "/tests/sql_test_files/it_works_simple.sql", + app_data, + ) + .await? + .to_srv_request(); + let resp = main_handler(req).await?; + + assert_eq!(resp.status(), StatusCode::OK); + assert!( + resp.headers().get("Server-Timing").is_none(), + "Server-Timing header should not be present in production mode" + ); + Ok(()) +} + +#[actix_web::test] +async fn test_server_timing_enabled_in_development() -> actix_web::Result<()> { + let mut config = test_config(); + config.environment = sqlpage::app_config::DevOrProd::Development; + let app_data = make_app_data_from_config(config).await; + + let req = crate::common::get_request_to_with_data( + "/tests/sql_test_files/it_works_sqrt.sql", + app_data, + ) + .await? + .to_srv_request(); + let resp = main_handler(req).await?; + + assert_eq!(resp.status(), StatusCode::OK); + let server_timing_header = resp + .headers() + .get("Server-Timing") + .expect("Server-Timing header should be present in development mode"); + let header_value = server_timing_header.to_str().unwrap(); + + assert!( + header_value.contains("sql_file;dur="), + "Should contain sql_file timing: {header_value}" + ); + assert!( + header_value.contains("parse_req;dur="), + "Should contain parse_req timing: {header_value}" + ); + assert!( + header_value.contains("db_conn;dur="), + "Should contain db_conn timing: {header_value}" + ); + assert!( + header_value.contains("row;dur="), + "Should contain row timing: {header_value}" + ); + + Ok(()) +} + +#[actix_web::test] +async fn test_server_timing_format() -> actix_web::Result<()> { + let req = get_request_to("/tests/sql_test_files/it_works_sqrt.sql") + .await? + .to_srv_request(); + let resp = main_handler(req).await?; + + assert_eq!(resp.status(), StatusCode::OK); + let server_timing_header = resp.headers().get("Server-Timing").unwrap(); + let header_value = server_timing_header.to_str().unwrap(); + + let parts: Vec<&str> = header_value.split(", ").collect(); + assert!(parts.len() >= 4, "Should have at least 4 timing events"); + + for part in parts { + assert!( + part.contains(";dur="), + "Each part should have name;dur= format: {part}" + ); + let dur_parts: Vec<&str> = part.split(";dur=").collect(); + assert_eq!(dur_parts.len(), 2, "Should have name and duration: {part}"); + let duration: f64 = dur_parts[1] + .parse() + .expect("Duration should be a valid number"); + assert!( + duration >= 0.0, + "Duration should be non-negative: {duration}" + ); + } + + Ok(()) +}