diff --git a/Cargo.lock b/Cargo.lock index 2507283d..7b3789b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1284,6 +1284,7 @@ dependencies = [ "hyper-util", "indexmap 2.11.0", "itoa", + "ntex-http", "ryu", "serde", "sonic-rs", @@ -1837,12 +1838,12 @@ dependencies = [ [[package]] name = "ntex-http" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3102673534f57dbc7fc9e7f1aac4126f353366c67518eac0a7763bb2515f0a7a" +checksum = "61da3d6c8bec83c5481d7e36ed4cf1a8cd0edee3e2fa411290932b17549d5cf2" dependencies = [ + "ahash", "futures-core", - "fxhash", "http", "itoa", "log", diff --git a/bin/dev-cli/src/main.rs b/bin/dev-cli/src/main.rs index 45fa31bd..e5c4135b 100644 --- a/bin/dev-cli/src/main.rs +++ b/bin/dev-cli/src/main.rs @@ -21,6 +21,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte fn main() { let tree_layer = tracing_tree::HierarchicalLayer::new(2) + .with_writer(std::io::stdout) .with_bracketed_fields(true) .with_deferred_spans(false) .with_wraparound(25) diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 8e126228..eb34deca 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -65,6 +65,7 @@ pub async fn execute_plan( query_plan: query_plan_payload, projection_plan: &normalized_payload.projection_plan, variable_values: &variable_payload.variables_map, + upstream_headers: req.headers(), extensions, introspection_context: &introspection_context, operation_type_name: normalized_payload.root_type_name, diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 9193459b..e09921fd 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -21,6 +21,7 @@ async-trait = { workspace = true } futures = { workspace = true } http = { workspace = true } http-body-util = { workspace = true } +ntex-http = "0.1.15" hyper = { workspace = true, features = ["client"] } serde = { workspace = true } sonic-rs = { workspace = true } diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 200ee75e..85f0f462 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -6,6 +6,7 @@ use hive_router_query_planner::planner::plan_nodes::{ ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode, QueryPlan, SequenceNode, }; +use ntex_http::HeaderMap; use serde::Deserialize; use sonic_rs::ValueRef; @@ -36,6 +37,7 @@ pub struct QueryPlanExecutionContext<'exec> { pub query_plan: &'exec QueryPlan, pub projection_plan: &'exec Vec, pub variable_values: &'exec Option>, + pub upstream_headers: &'exec HeaderMap, pub extensions: Option>, pub introspection_context: &'exec IntrospectionContext<'exec, 'static>, pub operation_type_name: &'exec str, @@ -60,6 +62,7 @@ pub async fn execute_query_plan<'exec>( ctx.introspection_context.metadata, // Deduplicate subgraph requests only if the operation type is a query ctx.operation_type_name == "Query", + ctx.upstream_headers, ); executor .execute(&mut exec_ctx, ctx.query_plan.node.as_ref()) @@ -83,6 +86,7 @@ pub struct Executor<'exec> { schema_metadata: &'exec SchemaMetadata, executors: &'exec SubgraphExecutorMap, dedupe_subgraph_requests: bool, + upstream_headers: &'exec HeaderMap, } struct ConcurrencyScope<'exec, T> { @@ -150,12 +154,14 @@ impl<'exec> Executor<'exec> { executors: &'exec SubgraphExecutorMap, schema_metadata: &'exec SchemaMetadata, dedupe_subgraph_requests: bool, + upstream_headers: &'exec HeaderMap, ) -> Self { Executor { variable_values, executors, schema_metadata, dedupe_subgraph_requests, + upstream_headers, } } @@ -533,6 +539,7 @@ impl<'exec> Executor<'exec> { operation_name: node.operation_name.as_deref(), variables: None, representations, + upstream_headers: self.upstream_headers, }, ) .await, diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 7862a2c6..38f5aab9 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use bytes::Bytes; +use ntex_http::HeaderMap; #[async_trait] pub trait SubgraphExecutor { @@ -21,6 +22,7 @@ pub type SubgraphExecutorBoxedArc = Arc>; pub struct HttpExecutionRequest<'a> { pub query: &'a str, pub dedupe: bool, + pub upstream_headers: &'a HeaderMap, pub operation_name: Option<&'a str>, // TODO: variables could be stringified before even executing the request pub variables: Option>, diff --git a/lib/executor/src/executors/dedupe.rs b/lib/executor/src/executors/dedupe.rs index 8ab6b844..e713aa2b 100644 --- a/lib/executor/src/executors/dedupe.rs +++ b/lib/executor/src/executors/dedupe.rs @@ -1,8 +1,9 @@ use ahash::AHasher; use bytes::Bytes; use http::{HeaderMap, Method, StatusCode, Uri}; +use ntex_http::HeaderMap as NtexHeaderMap; use std::collections::BTreeMap; -use std::hash::{BuildHasherDefault, Hash, Hasher}; +use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; #[derive(Debug, Clone)] pub struct SharedResponse { @@ -11,66 +12,52 @@ pub struct SharedResponse { pub body: Bytes, } -#[derive(Debug, Clone, Eq)] -pub struct RequestFingerprint { - method: Method, - url: Uri, - /// BTreeMap to ensure case-insensitivity and consistent order for hashing - headers: BTreeMap, - body: Vec, -} +pub fn request_fingerprint( + method: &Method, + url: &Uri, + req_headers: &HeaderMap, + upstream_headers: &NtexHeaderMap, + body_bytes: &[u8], + fingerprint_headers: &[String], +) -> u64 { + let build_hasher = ABuildHasher::default(); + let mut hasher = build_hasher.build_hasher(); + + // BTreeMap to ensure case-insensitivity and consistent order for hashing + let mut headers = BTreeMap::new(); + if fingerprint_headers.is_empty() { + // fingerprint all headers -impl RequestFingerprint { - pub fn new( - method: &Method, - url: &Uri, - req_headers: &HeaderMap, - body_bytes: &[u8], - fingerprint_headers: &[String], - ) -> Self { - let mut headers = BTreeMap::new(); - if fingerprint_headers.is_empty() { - // fingerprint all headers - for (key, value) in req_headers.iter() { + for (key, value) in req_headers.iter() { + if let Ok(value_str) = value.to_str() { + headers.insert(key.as_str(), value_str); + } + } + for (key, value) in upstream_headers.iter() { + if let Ok(value_str) = value.to_str() { + headers.insert(key.as_str(), value_str); + } + } + } else { + for header_name in fingerprint_headers.iter() { + if let Some(value) = req_headers.get(header_name) { if let Ok(value_str) = value.to_str() { - headers.insert(key.as_str().to_lowercase(), value_str.to_string()); + headers.insert(header_name, value_str); } - } - } else { - for header_name in fingerprint_headers.iter() { - if let Some(value) = req_headers.get(header_name) { - if let Ok(value_str) = value.to_str() { - headers.insert(header_name.to_lowercase(), value_str.to_string()); - } + } else if let Some(value) = upstream_headers.get(header_name) { + if let Ok(value_str) = value.to_str() { + headers.insert(header_name, value_str); } } } - - Self { - method: method.clone(), - url: url.clone(), - headers, - body: body_bytes.to_vec(), - } } -} -impl Hash for RequestFingerprint { - fn hash(&self, state: &mut H) { - self.method.hash(state); - self.url.hash(state); - self.headers.hash(state); - self.body.hash(state); - } -} + method.hash(&mut hasher); + url.hash(&mut hasher); + headers.hash(&mut hasher); + body_bytes.hash(&mut hasher); -impl PartialEq for RequestFingerprint { - fn eq(&self, other: &Self) -> bool { - self.method == other.method - && self.url == other.url - && self.headers == other.headers - && self.body == other.body - } + hasher.finish() } pub type ABuildHasher = BuildHasherDefault; diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 559b3dd3..26a62757 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -1,6 +1,7 @@ use std::sync::Arc; -use crate::executors::dedupe::{ABuildHasher, RequestFingerprint, SharedResponse}; +use crate::executors::dedupe::request_fingerprint; +use crate::executors::dedupe::{ABuildHasher, SharedResponse}; use dashmap::DashMap; use hive_router_config::traffic_shaping::TrafficShapingExecutorConfig; use tokio::sync::OnceCell; @@ -33,8 +34,7 @@ pub struct HTTPSubgraphExecutor { pub header_map: HeaderMap, pub semaphore: Arc, pub config: Arc, - pub in_flight_requests: - Arc>, ABuildHasher>>, + pub in_flight_requests: Arc>, ABuildHasher>>, } const FIRST_VARIABLE_STR: &[u8] = b",\"variables\":{"; @@ -46,9 +46,7 @@ impl HTTPSubgraphExecutor { http_client: Arc, Full>>, semaphore: Arc, config: Arc, - in_flight_requests: Arc< - DashMap>, ABuildHasher>, - >, + in_flight_requests: Arc>, ABuildHasher>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -184,10 +182,11 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { }; } - let fingerprint = RequestFingerprint::new( + let fingerprint = request_fingerprint( &http::Method::POST, &self.endpoint, &self.header_map, + execution_request.upstream_headers, &body, &self.config.dedupe_fingerprint_headers, ); @@ -196,7 +195,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // Prevents any deadlocks. let cell = self .in_flight_requests - .entry(fingerprint.clone()) + .entry(fingerprint) .or_default() .value() .clone(); diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 69d7e7de..37431b29 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -14,7 +14,7 @@ use tokio::sync::{OnceCell, Semaphore}; use crate::{ executors::{ common::{HttpExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, - dedupe::{ABuildHasher, RequestFingerprint, SharedResponse}, + dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, http::HTTPSubgraphExecutor, }, @@ -81,9 +81,8 @@ impl SubgraphExecutorMap { let semaphores_by_origin: DashMap> = DashMap::new(); let max_connections_per_host = config.max_connections_per_host; let config_arc = Arc::new(config); - let in_flight_requests: Arc< - DashMap>, ABuildHasher>, - > = Arc::new(DashMap::with_hasher(ABuildHasher::default())); + let in_flight_requests: Arc>, ABuildHasher>> = + Arc::new(DashMap::with_hasher(ABuildHasher::default())); let executor_map = subgraph_endpoint_map .into_iter() diff --git a/lib/router-config/src/traffic_shaping.rs b/lib/router-config/src/traffic_shaping.rs index eec853fd..f7e4310c 100644 --- a/lib/router-config/src/traffic_shaping.rs +++ b/lib/router-config/src/traffic_shaping.rs @@ -21,7 +21,10 @@ pub struct TrafficShapingExecutorConfig { /// A list of headers that should be used to fingerprint requests for deduplication. /// /// If not provided, the default is to use the "authorization" header only. - #[serde(default = "default_dedupe_fingerprint_headers")] + #[serde( + default = "default_dedupe_fingerprint_headers", + deserialize_with = "deserialize_and_normalize_dedupe_fingerprint_headers" + )] pub dedupe_fingerprint_headers: Vec, } @@ -51,3 +54,17 @@ fn default_dedupe_enabled() -> bool { fn default_dedupe_fingerprint_headers() -> Vec { vec!["authorization".to_string()] } + +fn deserialize_and_normalize_dedupe_fingerprint_headers<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let headers: Vec = Deserialize::deserialize(deserializer)?; + Ok(normalize_dedupe_fingerprint_headers(headers)) +} + +fn normalize_dedupe_fingerprint_headers(headers: Vec) -> Vec { + headers.into_iter().map(|h| h.to_lowercase()).collect() +}