Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rust/crates/cloudsearch-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ pub struct SearchRequest {
pub from: Option<usize>,
pub size: Option<usize>,
pub sort: Option<SortSpec>,
#[serde(skip_serializing_if = "Option::is_none")]
pub search_after: Option<Vec<serde_json::Value>>,
pub aggs: Option<BTreeMap<String, AggregationRequest>>,
}

Expand Down Expand Up @@ -378,6 +380,8 @@ pub struct SearchHit {
pub score: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none", rename = "highlight")]
pub highlight: Option<std::collections::BTreeMap<String, Vec<String>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sort_values: Option<Vec<serde_json::Value>>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
Expand Down
7 changes: 7 additions & 0 deletions rust/crates/cloudsearch-common/tests/round_trip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ fn test_search_request_all_fields() {
field: "status".to_string(),
}),
)])),
search_after: None,
});
}

Expand Down Expand Up @@ -713,12 +714,14 @@ fn test_search_response_with_hits_and_aggs() {
source: serde_json::json!({"title": "First"}),
score: None,
highlight: None,
sort_values: None,
},
SearchHit {
id: "doc2".to_string(),
source: serde_json::json!({"title": "Second"}),
score: None,
highlight: None,
sort_values: None,
},
],
},
Expand Down Expand Up @@ -756,6 +759,7 @@ fn test_hits_metadata() {
source: serde_json::json!({"x": 1}),
score: None,
highlight: None,
sort_values: None,
}],
});
}
Expand All @@ -769,6 +773,7 @@ fn test_search_hit() {
source: serde_json::json!({"nested": {"field": "value"}}),
score: None,
highlight: None,
sort_values: None,
});
}

Expand Down Expand Up @@ -1085,6 +1090,7 @@ fn test_search_hit_with_highlight() {
"message".to_string(),
vec!["<em>hello</em> world".to_string()],
)])),
sort_values: None,
});
}

Expand All @@ -1096,6 +1102,7 @@ fn test_skip_serializing_none_for_highlight() {
source: serde_json::json!({"x": 1}),
score: None,
highlight: None,
sort_values: None,
};
round_trip_static_str(&hit, |value| {
// highlight field should not appear in JSON when None
Expand Down
107 changes: 106 additions & 1 deletion rust/crates/cloudsearch-index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,12 +810,14 @@ impl IndexHandle {
source: l.source.clone(),
score: None,
highlight: None,
sort_values: None,
};
let rh = SearchHit {
id: r.id.clone(),
source: r.source.clone(),
score: None,
highlight: None,
sort_values: None,
};
compare_hits(&lh, &rh, sort)
});
Expand All @@ -827,7 +829,19 @@ impl IndexHandle {
});
}

let from = request.from.unwrap_or(0).min(MAX_SEARCH_OFFSET);
let from = if let Some(cursor) = &request.search_after {
// For search_after, find position where doc > cursor
scored
.iter()
.position(|(score, doc)| {
let doc_sort_values = compute_sort_values(doc, request.sort.as_ref(), *score);
compare_sort_values_list(&doc_sort_values, cursor, request.sort.as_ref())
== std::cmp::Ordering::Greater
})
.unwrap_or(scored.len())
} else {
request.from.unwrap_or(0).min(MAX_SEARCH_OFFSET)
};
let size = request.size.unwrap_or(total).min(MAX_SEARCH_SIZE);

let hits = scored
Expand All @@ -840,11 +854,13 @@ impl IndexHandle {
r if !r.is_empty() => extract_highlight(doc, doc_id_hash, r, query),
_ => None,
};
let sort_values = compute_sort_values(doc, request.sort.as_ref(), score);
SearchHit {
id: doc.id.clone(),
source: doc.source.clone(),
score: Some(score),
highlight,
sort_values: Some(sort_values),
}
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -933,6 +949,18 @@ impl IndexHandle {
)));
}

if request.search_after.is_some() && request.from.is_some() {
return Err(CloudSearchError::InvalidSearchRequest(
"search_after and from cannot be used together".to_string(),
));
}

if request.search_after.is_some() && request.sort.is_none() {
return Err(CloudSearchError::InvalidSearchRequest(
"search_after requires sort field to be specified".to_string(),
));
}

if let Some(aggs) = &request.aggs {
for (name, agg) in aggs {
match agg {
Expand Down Expand Up @@ -2299,6 +2327,83 @@ fn comparable_value(value: &serde_json::Value) -> Option<ComparableValue> {
None
}

/// Compute sort values for a document.
/// Returns a vector: [`sort_field_value`, `tie_breaker`].
/// Tie-breaker is [`score`, `doc_id`].
/// This is used for both `search_after` cursor positioning and response `sort_values`.
fn compute_sort_values(
doc: &IndexDocument,
sort: Option<&SortSpec>,
score: f32,
) -> Vec<serde_json::Value> {
let mut values = Vec::with_capacity(2);
if let Some(sort_spec) = sort {
if let Some(field_value) = doc.source.get(&sort_spec.field) {
values.push(field_value.clone());
} else {
values.push(serde_json::Value::Null);
}
}
// Tie-breaker: include score and doc_id to ensure uniqueness
values.push(serde_json::Value::Number(
serde_json::Number::from_f64(f64::from(score)).unwrap_or(serde_json::Number::from(0)),
));
values.push(serde_json::Value::String(doc.id.clone()));
values
}

/// Compare a document's sort values against a `search_after` cursor.
/// Returns `Ordering::Greater` when doc should come AFTER the cursor
/// (i.e., cursor is smaller than or equal to doc's values).
fn compare_sort_values_list(
doc_values: &[serde_json::Value],
cursor: &[serde_json::Value],
sort: Option<&SortSpec>,
) -> std::cmp::Ordering {
for (i, cursor_val) in cursor.iter().enumerate() {
if i >= doc_values.len() {
return std::cmp::Ordering::Less;
}
let doc_val = &doc_values[i];
let ordering = compare_json_values(doc_val, cursor_val, sort);
if ordering != std::cmp::Ordering::Equal {
return ordering;
}
}
std::cmp::Ordering::Equal
}

fn compare_json_values(
left: &serde_json::Value,
right: &serde_json::Value,
sort: Option<&SortSpec>,
) -> std::cmp::Ordering {
let left_comp = comparable_value(left);
let right_comp = comparable_value(right);

let ordering = match (left_comp, right_comp) {
(Some(l), Some(r)) => match (&l, &r) {
(ComparableValue::Number(ln), ComparableValue::Number(rn)) => ln.total_cmp(rn),
(ComparableValue::Timestamp(lt), ComparableValue::Timestamp(rt)) => lt.cmp(rt),
(ComparableValue::String(ls), ComparableValue::String(rs)) => ls.cmp(rs),
(ComparableValue::Boolean(lb), ComparableValue::Boolean(rb)) => lb.cmp(rb),
_ => std::cmp::Ordering::Equal,
},
(None, Some(_)) => std::cmp::Ordering::Greater,
(Some(_), None) => std::cmp::Ordering::Less,
(None, None) => std::cmp::Ordering::Equal,
};

if let Some(sort_spec) = sort {
match sort_spec.order {
SortOrder::Asc => ordering,
SortOrder::Desc => ordering.reverse(),
}
} else {
ordering
}
}

fn compare_hits(left: &SearchHit, right: &SearchHit, sort: &SortSpec) -> std::cmp::Ordering {
let left_value = left.source.get(&sort.field).and_then(comparable_value);
let right_value = right.source.get(&sort.field).and_then(comparable_value);
Expand Down
30 changes: 30 additions & 0 deletions rust/crates/cloudsearch-index/tests/coverage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,36 @@ async fn validate_search_request_rejects_from_exceeding_max() {
);
}

#[tokio::test]
async fn validate_search_request_rejects_search_after_without_sort() {
let temp_dir = TempDir::new().expect("temp dir");
let catalog = Arc::new(IndexCatalog::new(temp_dir.path()));
catalog.initialize().await.expect("init catalog");
let _metadata = catalog
.create_index(
"test",
CreateIndexRequest {
settings: IndexSettings::default(),
..Default::default()
},
)
.await
.expect("create index");
let handle = catalog.open_index("test").await.expect("open index");

// search_after without sort is invalid — cursor is meaningless without sort order
let request = SearchRequest {
search_after: Some(vec![serde_json::json!(1.0), serde_json::json!("doc123")]),
..Default::default()
};

let result = handle.validate_search_request(&request);
assert!(
result.is_err(),
"search_after without sort field should be rejected"
);
}

#[tokio::test]
async fn highlight_positions_case_insensitive() {
// Index doc with mixed-case text, search for lowercase term.
Expand Down