Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1930,10 +1930,12 @@ impl AedbInstance {
key: &[u8],
consistency: ConsistencyMode,
) -> Result<Option<KvEntry>, QueryError> {
assert!(
!self.require_authenticated_calls,
"kv_get_no_auth called in secure/authenticated mode"
);
if self.require_authenticated_calls {
return Err(QueryError::PermissionDenied {
permission: "kv_get_no_auth is unavailable in secure mode".into(),
scope: "anonymous".into(),
});
}
self.kv_get_unchecked(project_id, scope_id, key, consistency)
.await
}
Expand Down Expand Up @@ -1977,10 +1979,12 @@ impl AedbInstance {
limit: u64,
consistency: ConsistencyMode,
) -> Result<Vec<(Vec<u8>, KvEntry)>, QueryError> {
assert!(
!self.require_authenticated_calls,
"kv_scan_prefix_no_auth called in secure/authenticated mode"
);
if self.require_authenticated_calls {
return Err(QueryError::PermissionDenied {
permission: "kv_scan_prefix_no_auth is unavailable in secure mode".into(),
scope: "anonymous".into(),
});
}
self.kv_scan_prefix_unchecked(project_id, scope_id, prefix, limit, consistency)
.await
}
Expand Down
19 changes: 19 additions & 0 deletions src/lib_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2238,6 +2238,25 @@ async fn query_no_auth_in_secure_mode_returns_structured_error() {
assert!(matches!(err, QueryError::PermissionDenied { .. }));
}

#[tokio::test]
async fn kv_no_auth_apis_in_secure_mode_return_structured_error() {
let dir = tempdir().expect("temp");
let db = AedbInstance::open_production(AedbConfig::production([4u8; 32]), dir.path())
.expect("open secure");

let get_err = db
.kv_get_no_auth("p", "app", b"k", ConsistencyMode::AtLatest)
.await
.expect_err("secure mode should reject kv_get_no_auth");
assert!(matches!(get_err, QueryError::PermissionDenied { .. }));

let scan_err = db
.kv_scan_prefix_no_auth("p", "app", b"k", 10, ConsistencyMode::AtLatest)
.await
.expect_err("secure mode should reject kv_scan_prefix_no_auth");
assert!(matches!(scan_err, QueryError::PermissionDenied { .. }));
}

#[tokio::test]
async fn existence_and_introspection_apis_report_catalog_state() {
let dir = tempdir().expect("temp");
Expand Down
139 changes: 135 additions & 4 deletions src/query/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub fn execute_query_with_options(
.map(|c| c.page_size)
.unwrap_or(max_scan_rows.min(100))
});
let effective_page_size = page_size;
let effective_page_size = page_size.min(max_scan_rows);
if let Some(result) =
try_primary_key_point_query(schema, table, &query, &cursor_state, snapshot_seq)?
{
Expand Down Expand Up @@ -206,9 +206,11 @@ pub fn execute_query_with_options(
&& query.order_by.is_empty()
&& query.aggregates.is_empty()
&& query.having.is_none()
&& let Some(limit) = query.limit
{
root = Box::new(LimitOperator::new(root, limit.saturating_add(1)));
root = Box::new(LimitOperator::new(
root,
effective_page_size.saturating_add(1),
));
}
}
ExecutionStage::Filter => {
Expand All @@ -233,7 +235,7 @@ pub fn execute_query_with_options(
})
.collect::<Result<Vec<_>, _>>()?;
let top_k_limit = if cursor_state.is_none() {
query.limit.map(|limit| limit.saturating_add(1))
Some(effective_page_size.saturating_add(1))
} else {
None
};
Expand Down Expand Up @@ -568,6 +570,51 @@ fn execute_join_query(
rows.retain(|r| crate::query::operators::eval_compiled_expr_public(&compiled, r));
}

if !query.aggregates.is_empty() {
let group_by_idx = query
.group_by
.iter()
.map(|name| {
columns
.iter()
.position(|c| c == name)
.ok_or_else(|| QueryError::ColumnNotFound {
table: "join".into(),
column: name.clone(),
})
})
.collect::<Result<Vec<_>, _>>()?;
let agg_col_idx = query
.aggregates
.iter()
.map(|agg| aggregate_col_idx(agg, &columns))
.collect::<Result<Vec<_>, _>>()?;

let mut aggregate = AggregateOperator::new(
Box::new(ScanOperator::new(rows)),
query.aggregates.clone(),
group_by_idx,
agg_col_idx,
);
let mut aggregated_rows = Vec::new();
while let Some(row) = aggregate.next() {
aggregated_rows.push(row);
}
rows = aggregated_rows;
columns = query.group_by.clone();
columns.extend(query.aggregates.iter().map(aggregate_output_name));
}

if let Some(having) = &query.having {
if query.aggregates.is_empty() {
return Err(QueryError::InvalidQuery {
reason: "having requires aggregate or group_by".into(),
});
}
let compiled = compile_expr(having, &columns, "join")?;
rows.retain(|r| crate::query::operators::eval_compiled_expr_public(&compiled, r));
}

if !query.order_by.is_empty() {
let order_pairs: Vec<(usize, crate::query::plan::Order)> = query
.order_by
Expand Down Expand Up @@ -2025,6 +2072,29 @@ mod tests {
assert!(matches!(err, QueryError::InvalidQuery { .. }));
}

#[test]
fn non_join_page_size_is_capped_by_max_scan_rows() {
let (keyspace, catalog) = setup();
let snapshot = keyspace.snapshot();
let result = execute_query_with_options(
&snapshot,
&catalog,
"A",
"app",
Query::select(&["*"])
.from("users")
.order_by("id", Order::Asc)
.limit(50),
&QueryOptions::default(),
9,
10,
)
.expect("bounded page");
assert_eq!(result.rows.len(), 10);
assert!(result.cursor.is_some());
assert!(result.rows_examined <= 100);
}

#[test]
fn join_scan_bound_is_enforced_when_full_scan_not_allowed() {
let (keyspace, mut catalog) = setup();
Expand Down Expand Up @@ -2181,6 +2251,67 @@ mod tests {
assert_eq!(result.rows.len(), 50);
}

#[test]
fn join_aggregate_count_and_having_are_applied() {
let (mut keyspace, mut catalog) = setup();
catalog
.create_table(
"A",
"app",
"profiles",
vec![
ColumnDef {
name: "user_id".into(),
col_type: ColumnType::Integer,
nullable: false,
},
ColumnDef {
name: "country".into(),
col_type: ColumnType::Text,
nullable: false,
},
],
vec!["user_id".into()],
)
.expect("profiles table");
for i in 0..50 {
keyspace.upsert_row(
"A",
"app",
"profiles",
vec![Value::Integer(i)],
Row::from_values(vec![
Value::Integer(i),
Value::Text(if i % 2 == 0 { "US" } else { "CA" }.into()),
]),
1,
);
}
let snapshot = keyspace.snapshot();
let result = execute_query(
&snapshot,
&catalog,
"A",
"app",
Query::select(&["p.country", "count_star"])
.from("users")
.alias("u")
.inner_join("profiles", "u.id", "user_id")
.with_last_join_alias("p")
.group_by(&["p.country"])
.aggregate(Aggregate::Count)
.having(Expr::Gt("count_star".into(), Value::Integer(20)))
.order_by("count_star", Order::Desc)
.limit(10),
)
.expect("join aggregate query");

assert_eq!(result.rows.len(), 2);
for row in result.rows {
assert!(matches!(row.values[1], Value::Integer(25)));
}
}

#[test]
fn left_join_supports_global_table_reference() {
let (mut keyspace, mut catalog) = setup();
Expand Down