diff --git a/Cargo.lock b/Cargo.lock index 454e368c..0108e30b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -731,6 +731,7 @@ dependencies = [ name = "executor" version = "0.0.1" dependencies = [ + "async-graphql", "async-trait", "bumpalo", "bytes", @@ -743,14 +744,17 @@ dependencies = [ "hyper", "hyper-util", "indexmap 2.10.0", + "insta", "itoa", "query-planner", "ryu", "serde", + "serde_json", "sonic-rs", "subgraphs", "thiserror 2.0.14", "tokio", + "tokio-test", "tracing", "xxhash-rust", ] @@ -2477,6 +2481,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-tungstenite" version = "0.26.2" diff --git a/bin/gateway/src/pipeline/error.rs b/bin/gateway/src/pipeline/error.rs index b97866c3..4a7f9fe1 100644 --- a/bin/gateway/src/pipeline/error.rs +++ b/bin/gateway/src/pipeline/error.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use axum::{body::Body, extract::rejection::QueryRejection, response::IntoResponse}; use executor::{execution::error::PlanExecutionError, response::graphql_error::GraphQLError}; @@ -6,7 +6,7 @@ use graphql_tools::validation::utils::ValidationError; use http::{HeaderName, Method, Request, Response, StatusCode}; use query_planner::{ast::normalization::error::NormalizationError, planner::PlannerError}; use serde::{Deserialize, Serialize}; -use sonic_rs::{object, Value}; +use sonic_rs::Value; use crate::pipeline::header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}; @@ -156,8 +156,10 @@ impl IntoResponse for PipelineError { let code = self.error.graphql_error_code(); let message = self.error.graphql_error_message(); + let mut extensions = HashMap::new(); + extensions.insert("code".to_string(), Value::from(code)); let graphql_error = GraphQLError { - extensions: Some(Value::from_iter(&object! 
{"code": code.to_string()})), + extensions: Some(extensions), message, path: None, locations: None, diff --git a/bin/gateway/src/pipeline/execution_service.rs b/bin/gateway/src/pipeline/execution_service.rs index 2b611359..f04eacc4 100644 --- a/bin/gateway/src/pipeline/execution_service.rs +++ b/bin/gateway/src/pipeline/execution_service.rs @@ -17,7 +17,7 @@ use crate::shared_state::GatewaySharedState; use axum::body::Body; use axum::response::IntoResponse; use executor::execute_query_plan; -use executor::execution::plan::QueryPlanExecutionContext; +use executor::execution::plan::ExecuteQueryPlanParams; use executor::introspection::resolve::IntrospectionContext; use http::{HeaderName, HeaderValue, Request, Response}; use tower::Service; @@ -105,7 +105,7 @@ impl Service> for ExecutionService { metadata: &app_state.schema_metadata, }; - match execute_query_plan(QueryPlanExecutionContext { + match execute_query_plan(ExecuteQueryPlanParams { query_plan: &query_plan_payload.query_plan, projection_plan: &normalized_payload.projection_plan, variable_values: &variable_payload.variables_map, diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 676329f2..d6aeeed7 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -38,6 +38,10 @@ bumpalo = "3.19.0" subgraphs = { path = "../../bench/subgraphs" } criterion = { workspace = true } tokio = { workspace = true } +async-graphql = "7.0.17" +tokio-test = "0.4.4" +insta = "1.42.1" +serde_json = "1.0.109" [[bench]] name = "executor_benches" diff --git a/lib/executor/src/context.rs b/lib/executor/src/context.rs index b756ca0d..a39db524 100644 --- a/lib/executor/src/context.rs +++ b/lib/executor/src/context.rs @@ -4,16 +4,16 @@ use query_planner::planner::plan_nodes::{FetchNode, FetchRewrite, QueryPlan}; use crate::response::{graphql_error::GraphQLError, storage::ResponsesStorage, value::Value}; -pub struct ExecutionContext<'a> { +pub struct QueryPlanExecutionContext<'a> { pub response_storage: ResponsesStorage, pub final_response: Value<'a>, pub errors: Vec, pub output_rewrites: OutputRewritesStorage, } -impl<'a> Default for ExecutionContext<'a> { +impl<'a> Default for QueryPlanExecutionContext<'a> { fn default() -> Self { - ExecutionContext { + QueryPlanExecutionContext { response_storage: Default::default(), output_rewrites: Default::default(), errors: Vec::new(), @@ -22,23 +22,15 @@ impl<'a> Default for ExecutionContext<'a> { } } -impl<'a> ExecutionContext<'a> { +impl<'a> QueryPlanExecutionContext<'a> { pub fn new(query_plan: &QueryPlan, init_final_response: Value<'a>) -> Self { - ExecutionContext { + QueryPlanExecutionContext { response_storage: ResponsesStorage::new(), output_rewrites: OutputRewritesStorage::from_query_plan(query_plan), errors: Vec::new(), final_response: init_final_response, } } - - pub fn handle_errors(&mut self, errors: Option>) { - if let Some(errors) = errors { - for error in errors { - self.errors.push(error); - } - } - } } #[derive(Default)] diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index cdf86499..7dd98ba8 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -1,18 +1,18 @@ -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use bytes::{BufMut, Bytes, BytesMut}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use query_planner::planner::plan_nodes::{ - ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode, - QueryPlan, SequenceNode, + 
ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, FlattenNodePathSegment, + ParallelNode, PlanNode, QueryPlan, SequenceNode, }; use serde::Deserialize; use sonic_rs::ValueRef; use crate::{ - context::ExecutionContext, + context::QueryPlanExecutionContext, execution::{error::PlanExecutionError, rewrites::FetchRewriteExt}, - executors::{common::HttpExecutionRequest, map::SubgraphExecutorMap}, + executors::{common::SubgraphExecutionRequest, map::SubgraphExecutorMap}, introspection::{ resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, @@ -23,16 +23,19 @@ use crate::{ response::project_by_operation, }, response::{ - graphql_error::GraphQLError, merge::deep_merge, subgraph_response::SubgraphResponse, + error_normalization::{add_subgraph_info_to_error, normalize_error_for_representation}, + graphql_error::GraphQLError, + merge::deep_merge, + subgraph_response::SubgraphResponse, value::Value, }, utils::{ - consts::{CLOSE_BRACKET, OPEN_BRACKET}, - traverse::{traverse_and_callback, traverse_and_callback_mut}, + consts::{CLOSE_BRACKET, OPEN_BRACKET, TYPENAME_FIELD_NAME}, + traverse::traverse_and_callback, }, }; -pub struct QueryPlanExecutionContext<'exec> { +pub struct ExecuteQueryPlanParams<'exec> { pub query_plan: &'exec QueryPlan, pub projection_plan: &'exec Vec, pub variable_values: &'exec Option>, @@ -43,7 +46,7 @@ pub struct QueryPlanExecutionContext<'exec> { } pub async fn execute_query_plan<'exec>( - ctx: QueryPlanExecutionContext<'exec>, + ctx: ExecuteQueryPlanParams<'exec>, ) -> Result { let init_value = if let Some(introspection_query) = ctx.introspection_context.query { resolve_introspection(introspection_query, ctx.introspection_context) @@ -51,10 +54,10 @@ pub async fn execute_query_plan<'exec>( Value::Null }; - let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value); + let mut exec_ctx = QueryPlanExecutionContext::new(ctx.query_plan, init_value); if ctx.query_plan.node.is_some() { - let executor = Executor::new( + let executor = QueryPlanExecutor::new( ctx.variable_values, ctx.executors, ctx.introspection_context.metadata, @@ -76,47 +79,25 @@ pub async fn execute_query_plan<'exec>( .map_err(|e| e.into()) } -pub struct Executor<'exec> { +pub struct QueryPlanExecutor<'exec> { variable_values: &'exec Option>, schema_metadata: &'exec SchemaMetadata, - executors: &'exec SubgraphExecutorMap, -} - -struct ConcurrencyScope<'exec, T> { - jobs: FuturesUnordered>, -} - -impl<'exec, T> ConcurrencyScope<'exec, T> { - fn new() -> Self { - Self { - jobs: FuturesUnordered::new(), - } - } - - fn spawn(&mut self, future: BoxFuture<'exec, T>) { - self.jobs.push(future); - } - - async fn join_all(mut self) -> Vec { - let mut results = Vec::with_capacity(self.jobs.len()); - while let Some(result) = self.jobs.next().await { - results.push(result); - } - results - } + subgraph_executors: &'exec SubgraphExecutorMap, } struct FetchJob { + subgraph_name: String, fetch_node_id: i64, response: Bytes, } struct FlattenFetchJob { + subgraph_name: String, flatten_node_path: FlattenNodePath, response: Bytes, fetch_node_id: i64, representation_hashes: Vec, - representation_hash_to_index: HashMap, + filtered_representations_hashes: HashMap>>, } enum ExecutionJob { @@ -138,23 +119,27 @@ impl From for Bytes { struct PreparedFlattenData { representations: BytesMut, representation_hashes: Vec, - representation_hash_to_index: HashMap, + filtered_representations_hashes: HashMap>>, } -impl<'exec> Executor<'exec> { +impl<'exec> QueryPlanExecutor<'exec> { pub fn new( 
variable_values: &'exec Option>, executors: &'exec SubgraphExecutorMap, schema_metadata: &'exec SchemaMetadata, ) -> Self { - Executor { + QueryPlanExecutor { variable_values, - executors, + subgraph_executors: executors, schema_metadata, } } - pub async fn execute(&self, ctx: &mut ExecutionContext<'exec>, plan: Option<&PlanNode>) { + pub async fn execute( + &self, + ctx: &mut QueryPlanExecutionContext<'exec>, + plan: Option<&PlanNode>, + ) { match plan { Some(PlanNode::Fetch(node)) => self.execute_fetch_wave(ctx, node).await, Some(PlanNode::Parallel(node)) => self.execute_parallel_wave(ctx, node).await, @@ -167,7 +152,11 @@ impl<'exec> Executor<'exec> { } } - async fn execute_fetch_wave(&self, ctx: &mut ExecutionContext<'exec>, node: &FetchNode) { + async fn execute_fetch_wave( + &self, + ctx: &mut QueryPlanExecutionContext<'exec>, + node: &FetchNode, + ) { match self.execute_fetch_node(node, None).await { Ok(result) => self.process_job_result(ctx, result), Err(err) => ctx.errors.push(GraphQLError { @@ -179,23 +168,29 @@ impl<'exec> Executor<'exec> { } } - async fn execute_sequence_wave(&self, ctx: &mut ExecutionContext<'exec>, node: &SequenceNode) { + async fn execute_sequence_wave( + &self, + ctx: &mut QueryPlanExecutionContext<'exec>, + node: &SequenceNode, + ) { for child in &node.nodes { Box::pin(self.execute_plan_node(ctx, child)).await; } } - async fn execute_parallel_wave(&self, ctx: &mut ExecutionContext<'exec>, node: &ParallelNode) { - let mut scope = ConcurrencyScope::new(); + async fn execute_parallel_wave( + &self, + ctx: &mut QueryPlanExecutionContext<'exec>, + node: &ParallelNode, + ) { + let mut scope = FuturesUnordered::new(); for child in &node.nodes { let job_future = self.prepare_job_future(child, &ctx.final_response); - scope.spawn(job_future); + scope.push(job_future); } - let results = scope.join_all().await; - - for result in results { + while let Some(result) = scope.next().await { match result { Ok(job) => { self.process_job_result(ctx, job); @@ -210,7 +205,7 @@ impl<'exec> Executor<'exec> { } } - async fn execute_plan_node(&self, ctx: &mut ExecutionContext<'exec>, node: &PlanNode) { + async fn execute_plan_node(&self, ctx: &mut QueryPlanExecutionContext<'exec>, node: &PlanNode) { match node { PlanNode::Fetch(fetch_node) => match self.execute_fetch_node(fetch_node, None).await { Ok(job) => { @@ -234,7 +229,7 @@ impl<'exec> Executor<'exec> { flatten_node, Some(p.representations), Some(p.representation_hashes), - Some(p.representation_hash_to_index), + Some(p.filtered_representations_hashes), ) .await { @@ -290,7 +285,7 @@ impl<'exec> Executor<'exec> { flatten_node, Some(p.representations), Some(p.representation_hashes), - Some(p.representation_hash_to_index), + Some(p.filtered_representations_hashes), )), Ok(None) => Box::pin(async { Ok(ExecutionJob::None) }), Err(e) => Box::pin(async move { Err(e) }), @@ -309,10 +304,10 @@ impl<'exec> Executor<'exec> { fn process_subgraph_response( &self, - ctx: &mut ExecutionContext<'exec>, + ctx: &mut QueryPlanExecutionContext<'exec>, response_bytes: Bytes, fetch_node_id: i64, - ) -> Option<(Value<'exec>, Option<&'exec Vec>)> { + ) -> Option<(SubgraphResponse<'exec>, Option<&'exec Vec>)> { let idx = ctx.response_storage.add_response(response_bytes); // SAFETY: The `bytes` are transmuted to the lifetime `'a` of the `ExecutionContext`. 
// This is safe because the `response_storage` is part of the `ExecutionContext` (`ctx`) @@ -344,31 +339,37 @@ impl<'exec> Executor<'exec> { } }; - ctx.handle_errors(response.errors); - - Some((response.data, output_rewrites)) + Some((response, output_rewrites)) } - fn process_job_result(&self, ctx: &mut ExecutionContext<'exec>, job: ExecutionJob) { + fn process_job_result(&self, ctx: &mut QueryPlanExecutionContext<'exec>, job: ExecutionJob) { match job { ExecutionJob::Fetch(job) => { - if let Some((mut data, output_rewrites)) = + if let Some((mut response, output_rewrites)) = self.process_subgraph_response(ctx, job.response, job.fetch_node_id) { if let Some(output_rewrites) = output_rewrites { for output_rewrite in output_rewrites { - output_rewrite.rewrite(&self.schema_metadata.possible_types, &mut data); + output_rewrite + .rewrite(&self.schema_metadata.possible_types, &mut response.data); } } - deep_merge(&mut ctx.final_response, data); + deep_merge(&mut ctx.final_response, response.data); + + if let Some(errors) = response.errors { + for mut error in errors { + error = add_subgraph_info_to_error(error, &job.subgraph_name); + ctx.errors.push(error); + } + } } } - ExecutionJob::FlattenFetch(job) => { - if let Some((mut data, output_rewrites)) = + ExecutionJob::FlattenFetch(mut job) => { + if let Some((ref mut response, output_rewrites)) = self.process_subgraph_response(ctx, job.response, job.fetch_node_id) { - if let Some(mut entities) = data.take_entities() { + if let Some(mut entities) = response.data.take_entities() { if let Some(output_rewrites) = output_rewrites { for output_rewrite in output_rewrites { for entity in &mut entities { @@ -377,30 +378,109 @@ impl<'exec> Executor<'exec> { } } } - - let mut index = 0; let normalized_path = job.flatten_node_path.as_slice(); - traverse_and_callback_mut( - &mut ctx.final_response, - normalized_path, - self.schema_metadata, - &mut |target| { - let hash = job.representation_hashes[index]; - if let Some(entity_index) = - job.representation_hash_to_index.get(&hash) - { - if let Some(entity) = entities.get(*entity_index) { - // SAFETY: `new_val` is a clone of an entity that lives for `'a`. - // The transmute is to satisfy the compiler, but the lifetime - // is valid. 
- let new_val: Value<'_> = - unsafe { std::mem::transmute(entity.clone()) }; - deep_merge(target, new_val); + 'entity_loop: for (entity, hash) in entities + .into_iter() + .zip(job.representation_hashes.iter_mut()) + { + if let Some(target_paths) = + job.filtered_representations_hashes.get_mut(hash) + { + for indexes_in_path in target_paths { + let mut target: &mut Value<'exec> = &mut ctx.final_response; + for path_segment in normalized_path { + match path_segment { + FlattenNodePathSegment::List => { + let index = indexes_in_path.pop_front().unwrap(); + if let Value::Array(arr) = target { + if let Some(item) = arr.get_mut(index) { + target = item; + } else { + continue 'entity_loop; // Skip if index is out of bounds + } + } else { + continue 'entity_loop; // Skip if target is not an array + } + } + FlattenNodePathSegment::Field(field_name) => { + if let Value::Object(map) = target { + if let Ok(idx) = map.binary_search_by_key( + &field_name.as_str(), + |(k, _)| k, + ) { + if let Some((_, value)) = map.get_mut(idx) { + target = value; + } else { + continue 'entity_loop; + // Skip if field not found + } + } else { + continue 'entity_loop; // Skip if field not found + } + } else { + continue 'entity_loop; // Skip if target is not an object + } + } + FlattenNodePathSegment::Cast(type_condition) => { + if let Some(map) = target.as_object() { + if let Ok(idx) = map.binary_search_by_key( + &TYPENAME_FIELD_NAME, + |(k, _)| k, + ) { + if let Some((_, type_name)) = map.get(idx) { + if let Some(type_name) = + type_name.as_str() + { + if !self + .schema_metadata + .possible_types + .entity_satisfies_type_condition( + type_name, + type_condition, + ) + { + continue 'entity_loop; // Skip if type condition is not satisfied + } + } + } + } + } + } + } + } + if !indexes_in_path.is_empty() { + // If there are still indexes left, we need to traverse them + while let Some(index) = indexes_in_path.pop_front() { + if let Value::Array(arr) = target { + if let Some(item) = arr.get_mut(index) { + target = item; + } else { + continue 'entity_loop; // Skip if index is out of bounds + } + } else { + continue 'entity_loop; // Skip if target is not an array + } + } } + let new_val: Value<'_> = + unsafe { std::mem::transmute(entity.clone()) }; + deep_merge(target, new_val); } - index += 1; - }, - ); + } + } + } + + if let Some(errors) = &response.errors { + for error in errors { + let normalized_errors = normalize_error_for_representation( + error, + &job.subgraph_name, + job.flatten_node_path.as_slice(), + &job.representation_hashes, + &job.filtered_representations_hashes, + ); + ctx.errors.extend(normalized_errors); + } } } } @@ -424,55 +504,68 @@ impl<'exec> Executor<'exec> { None => return Ok(None), }; - let mut index = 0; let normalized_path = flatten_node.path.as_slice(); let mut filtered_representations = BytesMut::new(); filtered_representations.put(OPEN_BRACKET); let proj_ctx = RequestProjectionContext::new(&self.schema_metadata.possible_types); let mut representation_hashes: Vec = Vec::new(); - let mut filtered_representations_hashes: HashMap = HashMap::new(); + let mut filtered_representations_hashes: HashMap>> = + HashMap::new(); let arena = bumpalo::Bump::new(); - + let mut number_of_indexes = 0; + for segment in normalized_path.iter() { + if *segment == FlattenNodePathSegment::List { + number_of_indexes += 1; + } + } traverse_and_callback( final_response, normalized_path, self.schema_metadata, - &mut |entity| { - let hash = entity.to_hash(&requires_nodes.items, proj_ctx.possible_types); - - if 
!entity.is_null() { - representation_hashes.push(hash); + VecDeque::with_capacity(number_of_indexes), + &mut |entity: &Value, + indexes_in_path: VecDeque| + -> Result<(), PlanExecutionError> { + if entity.is_null() { + return Ok(()); } - if filtered_representations_hashes.contains_key(&hash) { - return Ok::<(), PlanExecutionError>(()); - } + let hash = entity.to_hash(&requires_nodes.items, proj_ctx.possible_types); - let entity = if let Some(input_rewrites) = &fetch_node.input_rewrites { - let new_entity = arena.alloc(entity.clone()); - for input_rewrite in input_rewrites { - input_rewrite.rewrite(&self.schema_metadata.possible_types, new_entity); + let indexes_in_paths = filtered_representations_hashes.get_mut(&hash); + + match indexes_in_paths { + Some(indexes_in_paths) => { + indexes_in_paths.push(indexes_in_path); + } + None => { + let entity = if let Some(input_rewrites) = &fetch_node.input_rewrites { + let new_entity = arena.alloc(entity.clone()); + for input_rewrite in input_rewrites { + input_rewrite + .rewrite(&self.schema_metadata.possible_types, new_entity); + } + new_entity + } else { + entity + }; + + let is_projected = project_requires( + &proj_ctx, + &requires_nodes.items, + entity, + &mut filtered_representations, + filtered_representations_hashes.is_empty(), + None, + )?; + + if is_projected { + representation_hashes.push(hash); + filtered_representations_hashes.insert(hash, vec![indexes_in_path]); + } } - new_entity - } else { - entity - }; - - let is_projected = project_requires( - &proj_ctx, - &requires_nodes.items, - entity, - &mut filtered_representations, - filtered_representations_hashes.is_empty(), - None, - )?; - - if is_projected { - filtered_representations_hashes.insert(hash, index); } - index += 1; - Ok(()) }, )?; @@ -485,7 +578,7 @@ impl<'exec> Executor<'exec> { Ok(Some(PreparedFlattenData { representations: filtered_representations, representation_hashes, - representation_hash_to_index: filtered_representations_hashes, + filtered_representations_hashes, })) } @@ -494,18 +587,20 @@ impl<'exec> Executor<'exec> { node: &FlattenNode, representations: Option, representation_hashes: Option>, - filtered_representations_hashes: Option>, + filtered_representations_hashes: Option>>>, ) -> Result { Ok(match node.node.as_ref() { PlanNode::Fetch(fetch_node) => ExecutionJob::FlattenFetch(FlattenFetchJob { flatten_node_path: node.path.clone(), + subgraph_name: fetch_node.service_name.to_string(), response: self .execute_fetch_node(fetch_node, representations) .await? 
.into(), fetch_node_id: fetch_node.id, representation_hashes: representation_hashes.unwrap_or_default(), - representation_hash_to_index: filtered_representations_hashes.unwrap_or_default(), + filtered_representations_hashes: filtered_representations_hashes + .unwrap_or_default(), }), _ => ExecutionJob::None, }) @@ -517,12 +612,13 @@ impl<'exec> Executor<'exec> { representations: Option, ) -> Result { Ok(ExecutionJob::Fetch(FetchJob { + subgraph_name: node.service_name.to_string(), fetch_node_id: node.id, response: self - .executors + .subgraph_executors .execute( &node.service_name, - HttpExecutionRequest { + SubgraphExecutionRequest { query: node.operation.document_str.as_str(), operation_name: node.operation_name.as_deref(), variables: None, diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index a2d50521..22d78c76 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -5,7 +5,7 @@ use bytes::{Bytes, BytesMut}; #[async_trait] pub trait SubgraphExecutor { - async fn execute<'a>(&self, execution_request: HttpExecutionRequest<'a>) -> Bytes; + async fn execute<'a>(&self, execution_request: SubgraphExecutionRequest<'a>) -> Bytes; fn to_boxed_arc<'a>(self) -> Arc> where Self: Sized + Send + Sync + 'a, @@ -18,7 +18,7 @@ pub type SubgraphExecutorType = dyn crate::executors::common::SubgraphExecutor + pub type SubgraphExecutorBoxedArc = Arc>; -pub struct HttpExecutionRequest<'a> { +pub struct SubgraphExecutionRequest<'a> { pub query: &'a str, pub operation_name: Option<&'a str>, // TODO: variables could be stringified before even executing the request diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index af952c6e..148df69e 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -10,7 +10,7 @@ use http_body_util::Full; use hyper::{body::Bytes, Version}; use hyper_util::client::legacy::{connect::HttpConnector, Client}; -use crate::executors::common::HttpExecutionRequest; +use crate::executors::common::SubgraphExecutionRequest; use crate::executors::error::SubgraphExecutorError; use crate::response::graphql_error::GraphQLError; use crate::utils::consts::CLOSE_BRACE; @@ -51,7 +51,7 @@ impl HTTPSubgraphExecutor { async fn _execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> Result { // We may want to remove it, but let's see. 
let mut body = BytesMut::with_capacity(4096); @@ -123,7 +123,7 @@ impl HTTPSubgraphExecutor { #[async_trait] impl SubgraphExecutor for HTTPSubgraphExecutor { - async fn execute<'a>(&self, execution_request: HttpExecutionRequest<'a>) -> Bytes { + async fn execute<'a>(&self, execution_request: SubgraphExecutionRequest<'a>) -> Bytes { match self._execute(execution_request).await { Ok(bytes) => bytes, Err(e) => { diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 01c1b358..ece5b23f 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -8,7 +8,7 @@ use hyper_util::{ use crate::{ executors::{ - common::{HttpExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, + common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, error::SubgraphExecutorError, http::HTTPSubgraphExecutor, }, @@ -35,7 +35,7 @@ impl SubgraphExecutorMap { pub async fn execute<'a>( &self, subgraph_name: &str, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> Bytes { match self.inner.get(subgraph_name) { Some(executor) => executor.execute(execution_request).await, diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index 247fb077..079ef9f8 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -10,3 +10,6 @@ pub mod variables; pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; + +#[cfg(test)] +mod tests; diff --git a/lib/executor/src/response/error_normalization.rs b/lib/executor/src/response/error_normalization.rs new file mode 100644 index 00000000..588cf61b --- /dev/null +++ b/lib/executor/src/response/error_normalization.rs @@ -0,0 +1,186 @@ +use std::collections::{HashMap, VecDeque}; + +use query_planner::planner::plan_nodes::FlattenNodePathSegment; + +use crate::response::graphql_error::{GraphQLError, GraphQLErrorPathSegment}; + +/** + * Maps a subgraph error path of the form `[_entities, <index>, ...fields]` back to + * the client-facing response path, using the flatten path and the list indexes + * recorded while the representations were collected. + * + * For example, if the error path is `[_entities, 0, name]`, the flatten path is + * `[products, @, reviews, @, author]`, and the representation at index 0 was built + * from list indexes `[1, 2]`, the normalized path becomes + * `["products", 1, "reviews", 2, "author", "name"]`.
+ */ +pub fn normalize_error_for_representation( + error: &GraphQLError, + subgraph_name: &str, + normalized_path: &[FlattenNodePathSegment], + representation_hashes: &[u64], + hashes_to_indexes: &HashMap>>, +) -> Vec { + let mut new_errors: Vec = Vec::new(); + if let Some(path_in_error) = &error.path { + if let Some(GraphQLErrorPathSegment::String(first_path)) = path_in_error.first() { + if first_path == "_entities" { + if let Some(GraphQLErrorPathSegment::Index(entity_index)) = path_in_error.get(1) { + if let Some(representation_hash) = representation_hashes.get(*entity_index) { + if let Some(indexes_in_paths) = hashes_to_indexes.get(representation_hash) { + for indexes_in_path in indexes_in_paths { + let mut indexes_in_path = indexes_in_path.clone(); + let mut real_path: Vec = + Vec::with_capacity( + normalized_path.len() + path_in_error.len() - 2, + ); + for segment in normalized_path { + match segment { + FlattenNodePathSegment::Field(field_name) => { + real_path.push(GraphQLErrorPathSegment::String( + field_name.to_string(), + )); + } + FlattenNodePathSegment::List => { + if let Some(index_in_path) = indexes_in_path.pop_front() + { + real_path.push(GraphQLErrorPathSegment::Index( + index_in_path, + )); + } + } + FlattenNodePathSegment::Cast(_type_condition) => { + // Cast segments are not included in the error path + continue; + } + } + } + if !indexes_in_path.is_empty() { + // If there are still indexes left, we need to traverse them + while let Some(index) = indexes_in_path.pop_front() { + real_path.push(GraphQLErrorPathSegment::Index(index)); + } + } + real_path.extend_from_slice(&path_in_error[2..]); + let mut new_error = error.clone(); + if !real_path.is_empty() { + new_error.path = Some(real_path); + } + new_error = add_subgraph_info_to_error(new_error, subgraph_name); + new_errors.push(new_error); + } + return new_errors; + } + } + } + } + } + } + // Use the path without indexes in case of unlocated error + let mut real_path: Vec = Vec::with_capacity(normalized_path.len()); + for segment in normalized_path { + match segment { + FlattenNodePathSegment::Field(field_name) => { + real_path.push(GraphQLErrorPathSegment::String(field_name.to_string())); + } + FlattenNodePathSegment::List => { + break; + } + FlattenNodePathSegment::Cast(_type_condition) => { + // Cast segments are not included in the error path + continue; + } + } + } + let mut new_error = error.clone(); + if !real_path.is_empty() { + new_error.path = Some(real_path); + } + new_error = add_subgraph_info_to_error(new_error, subgraph_name); + new_errors.push(new_error); + new_errors +} + +pub fn add_subgraph_info_to_error(mut error: GraphQLError, subgraph_name: &str) -> GraphQLError { + let mut extensions = error.extensions.unwrap_or_default(); + if !extensions.contains_key("serviceName") { + extensions.insert("serviceName".to_string(), subgraph_name.into()); + } + if !extensions.contains_key("code") { + extensions.insert("code".to_string(), "DOWNSTREAM_SERVICE_ERROR".into()); + } + error.extensions = Some(extensions); + error +} + +#[test] +fn test_normalize_errors_for_representations() { + // "products", "@", "reviews", "@", "author" + let normalized_path = vec![ + FlattenNodePathSegment::Field("products".into()), + FlattenNodePathSegment::List, + FlattenNodePathSegment::Field("reviews".into()), + FlattenNodePathSegment::List, + FlattenNodePathSegment::Field("author".into()), + ]; + let mut indexes_in_paths: HashMap>> = HashMap::new(); + indexes_in_paths.insert(0, vec![VecDeque::from(vec![0, 0])]); + 
indexes_in_paths.insert(1, vec![VecDeque::from(vec![0, 1])]); + indexes_in_paths.insert(2, vec![VecDeque::from(vec![1, 1])]); + indexes_in_paths.insert(3, vec![VecDeque::from(vec![1, 2])]); + let representation_hashes: Vec = vec![0, 1, 2, 3]; + let errors: Vec = vec![ + GraphQLError { + message: "Error 1".to_string(), + locations: None, + path: Some(vec![ + GraphQLErrorPathSegment::String("_entities".to_string()), + GraphQLErrorPathSegment::Index(3), + GraphQLErrorPathSegment::String("name".to_string()), + ]), + extensions: None, + }, + GraphQLError { + message: "Error 2".to_string(), + locations: None, + path: Some(vec![ + GraphQLErrorPathSegment::String("_entities".to_string()), + GraphQLErrorPathSegment::Index(2), + GraphQLErrorPathSegment::String("age".to_string()), + ]), + extensions: None, + }, + ]; + let mut normalized_errors = Vec::new(); + for error in &errors { + let normalized_error = normalize_error_for_representation( + error, + "products", + &normalized_path, + &representation_hashes, + &indexes_in_paths, + ); + normalized_errors.extend(normalized_error); + } + println!("{:?}", normalized_errors); + assert_eq!(normalized_errors.len(), 2); + assert_eq!( + normalized_errors[0].path, + Some(vec![ + GraphQLErrorPathSegment::String("products".to_string()), + GraphQLErrorPathSegment::Index(1), + GraphQLErrorPathSegment::String("reviews".to_string()), + GraphQLErrorPathSegment::Index(2), + GraphQLErrorPathSegment::String("author".to_string()), + GraphQLErrorPathSegment::String("name".to_string()), + ]) + ); + assert_eq!( + normalized_errors[1].path, + Some(vec![ + GraphQLErrorPathSegment::String("products".to_string()), + GraphQLErrorPathSegment::Index(1), + GraphQLErrorPathSegment::String("reviews".to_string()), + GraphQLErrorPathSegment::Index(1), + GraphQLErrorPathSegment::String("author".to_string()), + GraphQLErrorPathSegment::String("age".to_string()), + ]) + ); +} diff --git a/lib/executor/src/response/graphql_error.rs b/lib/executor/src/response/graphql_error.rs index 88b9c6e8..1ac0c642 100644 --- a/lib/executor/src/response/graphql_error.rs +++ b/lib/executor/src/response/graphql_error.rs @@ -2,7 +2,7 @@ use graphql_parser::Pos; use graphql_tools::validation::utils::ValidationError; use serde::{de, Deserialize, Deserializer, Serialize}; use sonic_rs::Value; -use std::fmt; +use std::{collections::HashMap, fmt}; #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] @@ -12,7 +12,7 @@ pub struct GraphQLError { pub locations: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub path: Option>, - pub extensions: Option, + pub extensions: Option>, } impl From for GraphQLError { @@ -52,7 +52,7 @@ pub struct GraphQLErrorLocation { pub column: usize, } -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Debug, Serialize, PartialEq)] pub enum GraphQLErrorPathSegment { String(String), Index(usize), diff --git a/lib/executor/src/response/mod.rs b/lib/executor/src/response/mod.rs index efd86056..5d134e9e 100644 --- a/lib/executor/src/response/mod.rs +++ b/lib/executor/src/response/mod.rs @@ -1,3 +1,4 @@ +pub mod error_normalization; pub mod graphql_error; pub mod merge; pub mod storage; diff --git a/lib/executor/src/response/value.rs b/lib/executor/src/response/value.rs index c8691ace..f27b045f 100644 --- a/lib/executor/src/response/value.rs +++ b/lib/executor/src/response/value.rs @@ -183,6 +183,13 @@ impl<'a> Value<'a> { } } + pub fn as_array(&self) -> Option<&Vec>> { + match self { + Value::Array(arr) => Some(arr), + _ => None, + } + } + 
pub fn is_null(&self) -> bool { matches!(self, Value::Null) } diff --git a/lib/executor/src/tests/async_graphql.rs b/lib/executor/src/tests/async_graphql.rs new file mode 100644 index 00000000..bf18a4ea --- /dev/null +++ b/lib/executor/src/tests/async_graphql.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; + +use async_trait::async_trait; +use bytes::Bytes; +use sonic_rs::JsonContainerTrait; + +use crate::{ + executors::common::{SubgraphExecutionRequest, SubgraphExecutor}, + response::graphql_error::{GraphQLError, GraphQLErrorLocation, GraphQLErrorPathSegment}, +}; + +#[async_trait] +impl SubgraphExecutor for Executor +where + Executor: async_graphql::Executor, +{ + async fn execute<'a>(&self, execution_request: SubgraphExecutionRequest<'a>) -> Bytes { + let response: async_graphql::Response = self.execute(execution_request.into()).await; + serde_json::to_vec(&response).unwrap().into() + } +} + +impl<'a> From> for async_graphql::Request { + fn from(exec_request: SubgraphExecutionRequest) -> Self { + let mut req = async_graphql::Request::new(exec_request.query); + if let Some(variables) = exec_request.variables { + req = req.variables(async_graphql::Variables::from_json(serde_json::json!( + variables + ))); + } + if let Some(representations) = exec_request.representations { + req.variables.insert( + async_graphql::Name::new("representations"), + async_graphql::Value::from_json( + serde_json::from_slice(&representations).unwrap_or_default(), + ) + .unwrap(), + ); + } + if let Some(operation_name) = exec_request.operation_name { + req = req.operation_name(operation_name); + } + req + } +} + +impl From<&async_graphql::ServerError> for GraphQLError { + fn from(error: &async_graphql::ServerError) -> Self { + GraphQLError { + message: error.message.to_string(), + locations: Some( + error + .locations + .iter() + .map(|loc| GraphQLErrorLocation { + line: loc.line, + column: loc.column, + }) + .collect(), + ), + path: Some( + error + .path + .iter() + .map(|s| match s { + async_graphql::PathSegment::Field(name) => { + GraphQLErrorPathSegment::String(name.to_string()) + } + async_graphql::PathSegment::Index(index) => { + GraphQLErrorPathSegment::Index((*index).into()) + } + }) + .collect(), + ), + extensions: error.extensions.as_ref().map(|ext| { + let serialized = sonic_rs::json!(ext); + serialized + .as_object() + .unwrap() + .iter() + .map(|(k, v)| (k.to_string(), v.clone())) + .collect::>() + }), + } + } +} diff --git a/lib/executor/src/tests/fixtures/error_propagation/directors.rs b/lib/executor/src/tests/fixtures/error_propagation/directors.rs new file mode 100644 index 00000000..29a8efe7 --- /dev/null +++ b/lib/executor/src/tests/fixtures/error_propagation/directors.rs @@ -0,0 +1,43 @@ +use async_graphql::{EmptyMutation, EmptySubscription, Object, Schema, SimpleObject, ID}; + +#[derive(SimpleObject, Clone)] +#[graphql(extends)] +pub struct Movie { + #[graphql(external)] + id: String, + director: Director, +} +pub struct Query; + +#[derive(SimpleObject, Clone)] +pub struct Director { + id: String, + name: String, +} + +#[Object(extends = true)] +impl Query { + #[graphql(entity)] + async fn movie_with_director(&self, id: ID) -> Result { + // Throw on purpose + if id == ID("2".to_string()) { + Err(async_graphql::Error::new( + "Director not found for movie with id 2", + )) + } else { + Ok(Movie { + id: id.to_string(), + director: Director { + id: "1".to_string(), + name: "Christopher Nolan".to_string(), + }, + }) + } + } +} + +pub fn get_subgraph() -> Schema { + Schema::build(Query, 
EmptyMutation, EmptySubscription) + .enable_federation() + .finish() +} diff --git a/lib/executor/src/tests/fixtures/error_propagation/mod.rs b/lib/executor/src/tests/fixtures/error_propagation/mod.rs new file mode 100644 index 00000000..a820ea31 --- /dev/null +++ b/lib/executor/src/tests/fixtures/error_propagation/mod.rs @@ -0,0 +1,2 @@ +pub mod directors; +pub mod movies; diff --git a/lib/executor/src/tests/fixtures/error_propagation/movies.rs b/lib/executor/src/tests/fixtures/error_propagation/movies.rs new file mode 100644 index 00000000..b178c1a2 --- /dev/null +++ b/lib/executor/src/tests/fixtures/error_propagation/movies.rs @@ -0,0 +1,32 @@ +use async_graphql::{EmptyMutation, EmptySubscription, Object, Schema, SimpleObject, ID}; + +#[derive(SimpleObject, Clone)] +pub struct Movie { + id: ID, + name: String, +} +pub struct Query; + +#[Object(extends = true)] +impl Query { + async fn movie(&self, id: ID) -> Movie { + // Simulate a movie fetch + if id == ID("1".to_string()) { + Movie { + id: ID("1".to_string()), + name: "Inception".to_string(), + } + } else { + Movie { + id: ID("2".to_string()), + name: "Interstellar".to_string(), + } + } + } +} + +pub fn get_subgraph() -> Schema { + Schema::build(Query, EmptyMutation, EmptySubscription) + .enable_federation() + .finish() +} diff --git a/lib/executor/src/tests/fixtures/error_propagation/operation.graphql b/lib/executor/src/tests/fixtures/error_propagation/operation.graphql new file mode 100644 index 00000000..e7e6f131 --- /dev/null +++ b/lib/executor/src/tests/fixtures/error_propagation/operation.graphql @@ -0,0 +1,17 @@ +fragment Movie on Movie { + id + name + director { + id + name + } +} + +query { + movie1: movie(id: "1") { + ...Movie + } + movie2: movie(id: "2") { + ...Movie + } +} diff --git a/lib/executor/src/tests/fixtures/error_propagation/supergraph.graphql b/lib/executor/src/tests/fixtures/error_propagation/supergraph.graphql new file mode 100644 index 00000000..c206c53e --- /dev/null +++ b/lib/executor/src/tests/fixtures/error_propagation/supergraph.graphql @@ -0,0 +1,84 @@ +schema + @link(url: "https://specs.apollo.dev/link/v1.0") + @link(url: "https://specs.apollo.dev/join/v0.3", for: EXECUTION) { + query: Query +} + +directive @join__enumValue(graph: join__Graph!) repeatable on ENUM_VALUE + +directive @join__graph(name: String!, url: String!) on ENUM_VALUE + +directive @join__field( + graph: join__Graph + requires: join__FieldSet + provides: join__FieldSet + type: String + external: Boolean + override: String + usedOverridden: Boolean +) repeatable on FIELD_DEFINITION | INPUT_FIELD_DEFINITION + +directive @join__implements( + graph: join__Graph! + interface: String! +) repeatable on OBJECT | INTERFACE + +directive @join__type( + graph: join__Graph! + key: join__FieldSet + extension: Boolean! = false + resolvable: Boolean! = true + isInterfaceObject: Boolean! = false +) repeatable on OBJECT | INTERFACE | UNION | ENUM | INPUT_OBJECT | SCALAR + +directive @join__unionMember( + graph: join__Graph! + member: String! +) repeatable on UNION + +scalar join__FieldSet + +directive @link( + url: String + as: String + for: link__Purpose + import: [link__Import] +) repeatable on SCHEMA + +scalar link__Import + +enum link__Purpose { + """ + `SECURITY` features provide metadata necessary to securely resolve fields. + """ + SECURITY + + """ + `EXECUTION` features provide metadata necessary for operation execution. 
+ """ + EXECUTION +} + +enum join__Graph { + DIRECTORS @join__graph(name: "directors", url: "") + MOVIES @join__graph(name: "movies", url: "") +} + +type Director @join__type(graph: DIRECTORS) { + id: ID! + name: String! +} + +type Movie + @join__type(graph: DIRECTORS, key: "id") + @join__type(graph: MOVIES, key: "id") { + id: ID! + @join__field(graph: DIRECTORS, external: true) + @join__field(graph: MOVIES) + director: Director @join__field(graph: DIRECTORS) + name: String! @join__field(graph: MOVIES) +} + +type Query @join__type(graph: DIRECTORS) @join__type(graph: MOVIES) { + movie: Movie! @join__field(graph: MOVIES) +} diff --git a/lib/executor/src/tests/fixtures/mod.rs b/lib/executor/src/tests/fixtures/mod.rs new file mode 100644 index 00000000..3c491635 --- /dev/null +++ b/lib/executor/src/tests/fixtures/mod.rs @@ -0,0 +1 @@ +pub mod error_propagation; diff --git a/lib/executor/src/tests/mod.rs b/lib/executor/src/tests/mod.rs new file mode 100644 index 00000000..cc9592dd --- /dev/null +++ b/lib/executor/src/tests/mod.rs @@ -0,0 +1,83 @@ +use query_planner::graph::PlannerOverrideContext; +use sonic_rs::JsonValueTrait; + +use crate::response::graphql_error::GraphQLErrorPathSegment; +use crate::{ + context::QueryPlanExecutionContext, execution::plan::QueryPlanExecutor, + executors::common::SubgraphExecutor, introspection::schema::SchemaWithMetadata, + SubgraphExecutorMap, +}; + +mod async_graphql; +mod fixtures; + +#[test] +fn error_propagation() { + let supergraph_sdl = + std::fs::read_to_string("./src/tests/fixtures/error_propagation/supergraph.graphql") + .expect("Unable to read input file"); + let parsed_schema = query_planner::utils::parsing::parse_schema(&supergraph_sdl); + let planner = query_planner::planner::Planner::new_from_supergraph(&parsed_schema) + .expect("Failed to create planner from supergraph"); + let parsed_document = query_planner::utils::parsing::parse_operation( + &std::fs::read_to_string("./src/tests/fixtures/error_propagation/operation.graphql") + .expect("Unable to read input file"), + ); + let normalized_document = query_planner::ast::normalization::normalize_operation( + &planner.supergraph, + &parsed_document, + None, + ) + .expect("Failed to normalize operation"); + let normalized_operation = normalized_document.executable_operation(); + let query_plan = planner + .plan_from_normalized_operation(normalized_operation, PlannerOverrideContext::default()) + .expect("Failed to create query plan"); + + let schema_metadata = SchemaWithMetadata::schema_metadata(&planner.consumer_schema); + let movies_subgraph = fixtures::error_propagation::movies::get_subgraph(); + let directors_subgraph = fixtures::error_propagation::directors::get_subgraph(); + let mut subgraph_executor_map = SubgraphExecutorMap::new(); + subgraph_executor_map.insert_boxed_arc( + "movies".to_string(), + SubgraphExecutor::to_boxed_arc(movies_subgraph), + ); + subgraph_executor_map.insert_boxed_arc( + "directors".to_string(), + SubgraphExecutor::to_boxed_arc(directors_subgraph), + ); + tokio_test::block_on(async { + let qp_executor = QueryPlanExecutor::new(&None, &subgraph_executor_map, &schema_metadata); + let mut qp_exec_ctx = + QueryPlanExecutionContext::new(&query_plan, crate::response::value::Value::Null); + qp_executor + .execute(&mut qp_exec_ctx, query_plan.node.as_ref()) + .await; + assert_eq!(qp_exec_ctx.errors.len(), 1); + let error = &qp_exec_ctx.errors[0]; + assert_eq!( + error.path, + Some(vec![GraphQLErrorPathSegment::String("movie2".to_string())]) + ); + assert_eq!(error.message, 
"Director not found for movie with id 2"); + assert_eq!( + error + .extensions + .as_ref() + .map(|ext| ext.get("code").map(|v| v.as_str())) + .flatten() + .flatten(), + Some("DOWNSTREAM_SERVICE_ERROR") + ); + assert_eq!( + error + .extensions + .as_ref() + .map(|ext| ext.get("serviceName").map(|v| v.as_str())) + .flatten() + .flatten(), + Some("directors") + ); + insta::assert_snapshot!(qp_exec_ctx.final_response.to_string()); + }); +} diff --git a/lib/executor/src/tests/snapshots/executor__tests__error_propagation.snap b/lib/executor/src/tests/snapshots/executor__tests__error_propagation.snap new file mode 100644 index 00000000..63b156de --- /dev/null +++ b/lib/executor/src/tests/snapshots/executor__tests__error_propagation.snap @@ -0,0 +1,5 @@ +--- +source: lib/executor/src/tests/mod.rs +expression: qp_exec_ctx.final_response.to_string() +--- +{"movie1": {"__typename": "Movie", "director": {"id": "1", "name": "Christopher Nolan"}, "id": "1", "name": "Inception"}, "movie2": {"__typename": "Movie", "id": "2", "name": "Interstellar"}} diff --git a/lib/executor/src/utils/traverse.rs b/lib/executor/src/utils/traverse.rs index 3b80f22e..ab6ac7ec 100644 --- a/lib/executor/src/utils/traverse.rs +++ b/lib/executor/src/utils/traverse.rs @@ -1,3 +1,5 @@ +use std::collections::VecDeque; + use query_planner::planner::plan_nodes::FlattenNodePathSegment; use crate::{ @@ -5,94 +7,25 @@ use crate::{ utils::consts::TYPENAME_FIELD_NAME, }; -pub fn traverse_and_callback_mut<'a, Callback>( - current_data: &mut Value<'a>, - remaining_path: &[FlattenNodePathSegment], - schema_metadata: &SchemaMetadata, - callback: &mut Callback, -) where - Callback: FnMut(&mut Value), -{ - if remaining_path.is_empty() { - if let Value::Array(arr) = current_data { - // If the path is empty, we call the callback on each item in the array - // We iterate because we want the entity objects directly - for item in arr.iter_mut() { - callback(item); - } - } else { - // If the path is empty and current_data is not an array, just call the callback - callback(current_data); - } - return; - } - - match &remaining_path[0] { - FlattenNodePathSegment::List => { - // If the key is List, we expect current_data to be an array - if let Value::Array(arr) = current_data { - let rest_of_path = &remaining_path[1..]; - for item in arr.iter_mut() { - traverse_and_callback_mut(item, rest_of_path, schema_metadata, callback); - } - } - } - FlattenNodePathSegment::Field(field_name) => { - // If the key is Field, we expect current_data to be an object - if let Value::Object(map) = current_data { - if let Ok(idx) = map.binary_search_by_key(&field_name.as_str(), |(k, _)| k) { - let (_, next_data) = map.get_mut(idx).unwrap(); - let rest_of_path = &remaining_path[1..]; - traverse_and_callback_mut(next_data, rest_of_path, schema_metadata, callback); - } - } - } - FlattenNodePathSegment::Cast(type_condition) => { - // If the key is Cast, we expect current_data to be an object or an array - if let Value::Object(obj) = current_data { - let type_name = obj - .binary_search_by_key(&TYPENAME_FIELD_NAME, |(k, _)| k) - .ok() - .and_then(|idx| obj[idx].1.as_str()) - .unwrap_or(type_condition); - if schema_metadata - .possible_types - .entity_satisfies_type_condition(type_name, type_condition) - { - let rest_of_path = &remaining_path[1..]; - traverse_and_callback_mut( - current_data, - rest_of_path, - schema_metadata, - callback, - ); - } - } else if let Value::Array(arr) = current_data { - // If the current data is an array, we need to check each item - for item in 
arr.iter_mut() { - traverse_and_callback_mut(item, remaining_path, schema_metadata, callback); - } - } - } - } -} - pub fn traverse_and_callback<'a, E, Callback>( current_data: &'a Value<'a>, remaining_path: &'a [FlattenNodePathSegment], schema_metadata: &'a SchemaMetadata, + current_indexes: VecDeque, callback: &mut Callback, ) -> Result<(), E> where - Callback: FnMut(&'a Value<'a>) -> Result<(), E>, + Callback: FnMut(&'a Value<'a>, VecDeque) -> Result<(), E>, { if remaining_path.is_empty() { if let Value::Array(arr) = current_data { - for item in arr.iter() { - callback(item)?; + for (index, item) in arr.iter().enumerate() { + let mut new_indexes = current_indexes.clone(); + new_indexes.push_back(index); + callback(item, new_indexes)?; } } else { - callback(current_data)?; + callback(current_data, current_indexes)?; } return Ok(()); } @@ -101,8 +34,16 @@ where FlattenNodePathSegment::List => { if let Value::Array(arr) = current_data { let rest_of_path = &remaining_path[1..]; - for item in arr.iter() { - traverse_and_callback(item, rest_of_path, schema_metadata, callback)?; + for (index, item) in arr.iter().enumerate() { + let mut new_indexes = current_indexes.clone(); + new_indexes.push_back(index); + traverse_and_callback( + item, + rest_of_path, + schema_metadata, + new_indexes, + callback, + )?; } } } @@ -111,7 +52,13 @@ where if let Ok(idx) = map.binary_search_by_key(&field_name.as_str(), |(k, _)| k) { let (_, next_data) = &map[idx]; let rest_of_path = &remaining_path[1..]; - traverse_and_callback(next_data, rest_of_path, schema_metadata, callback)?; + traverse_and_callback( + next_data, + rest_of_path, + schema_metadata, + current_indexes, + callback, + )?; } } } @@ -127,11 +74,25 @@ where .entity_satisfies_type_condition(type_name, type_condition) { let rest_of_path = &remaining_path[1..]; - traverse_and_callback(current_data, rest_of_path, schema_metadata, callback)?; + traverse_and_callback( + current_data, + rest_of_path, + schema_metadata, + current_indexes, + callback, + )?; } } else if let Value::Array(arr) = current_data { - for item in arr.iter() { - traverse_and_callback(item, remaining_path, schema_metadata, callback)?; + for (index, item) in arr.iter().enumerate() { + let mut new_indexes = current_indexes.clone(); + new_indexes.push_back(index); + traverse_and_callback( + item, + remaining_path, + schema_metadata, + new_indexes, + callback, + )?; } } }
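
Note on the traversal change above: `traverse_and_callback` now threads a `VecDeque<usize>` of list indexes through every `List` segment it descends. Those trails are what `filtered_representations_hashes: HashMap<u64, Vec<VecDeque<usize>>>` stores per representation hash, and what `normalize_error_for_representation` replays to rewrite `[_entities, i, ...]` error paths, so both entities and their errors can be routed back to exact positions in the final response. Below is a minimal standalone sketch of that index-tracking idea, using `serde_json::Value` as a stand-in for the crate's internal `Value<'a>` and omitting the `Cast`/type-condition handling; names like `PathSegment` and `traverse` are illustrative, not the crate's API.

use std::collections::VecDeque;

use serde_json::{json, Value};

// Illustrative stand-in for `FlattenNodePathSegment` (no `Cast` variant here).
enum PathSegment {
    Field(&'static str),
    List,
}

// Walks `data` along `path`; every time a `List` segment (or the trailing
// entity array) is unrolled, the array index is pushed onto the trail that
// the callback receives together with the visited value.
fn traverse(
    data: &Value,
    path: &[PathSegment],
    indexes: VecDeque<usize>,
    callback: &mut impl FnMut(&Value, VecDeque<usize>),
) {
    match path.first() {
        None => {
            if let Value::Array(arr) = data {
                // Path exhausted on an array: unroll one more level so the
                // callback always sees entity objects with a complete trail.
                for (i, item) in arr.iter().enumerate() {
                    let mut trail = indexes.clone();
                    trail.push_back(i);
                    callback(item, trail);
                }
            } else {
                callback(data, indexes);
            }
        }
        Some(PathSegment::Field(name)) => {
            if let Some(next) = data.get(*name) {
                traverse(next, &path[1..], indexes, callback);
            }
        }
        Some(PathSegment::List) => {
            if let Value::Array(arr) = data {
                for (i, item) in arr.iter().enumerate() {
                    let mut trail = indexes.clone();
                    trail.push_back(i);
                    traverse(item, &path[1..], trail, callback);
                }
            }
        }
    }
}

fn main() {
    let data = json!({
        "products": [
            { "reviews": [ { "author": { "id": "a1" } } ] },
            { "reviews": [ { "author": { "id": "a2" } }, { "author": { "id": "a3" } } ] }
        ]
    });
    // Mirrors the `products.@.reviews.@.author` flatten path used in the tests.
    let path = [
        PathSegment::Field("products"),
        PathSegment::List,
        PathSegment::Field("reviews"),
        PathSegment::List,
        PathSegment::Field("author"),
    ];
    traverse(&data, &path, VecDeque::new(), &mut |entity, trail| {
        // Prints e.g. `[1, 0] -> {"id":"a2"}`: the trail is the response
        // position an `_entities` result (or error) would be merged back into.
        println!("{:?} -> {}", trail, entity);
    });
}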