Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 181 additions & 21 deletions rust/crates/cloudsearch-index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,13 +783,39 @@ impl IndexHandle {
let query = request.query.as_ref().unwrap_or(&SearchQuery::MatchAll);
let now = Utc::now();

// BM25 parameters
let k1 = 1.2f32;
let b = 0.75f32;
let n_docs = self.searchable_documents.len().max(1);

// Build IDF map: for each query term, compute IDF = log((N-df+0.5)/(df+0.5))
// We sum DF across all segment readers to get total document frequency.
// Deduplicate terms via BTreeSet to avoid redundant IDF lookups.
let target_field = extract_target_field(query);
let mut idf_map: std::collections::BTreeMap<String, f32> =
std::collections::BTreeMap::new();
let query_terms: std::collections::BTreeSet<String> =
extract_query_terms(query, "").into_iter().collect();
for term in &query_terms {
let mut total_df = 0usize;
for reader in &self.positions_readers {
if let Some(pl) = reader.get(term) {
total_df += pl.docs.len();
}
}
let idf = bm25_idf(total_df, n_docs);
idf_map.insert(term.clone(), idf);
}
let avg_field_len =
compute_avg_field_length(&self.searchable_documents, &target_field).max(1.0);

let mut scored: Vec<(f32, &IndexDocument)> = self
.searchable_documents
.iter()
.filter(|(_, doc)| !self.is_expired(&doc.id, now))
.filter_map(|(_, doc)| {
let doc_id_hash = hash_doc_id(&doc.id);
score_query(doc, query, doc_id_hash, &self.positions_readers).map(|s| (s, doc))
score_query(doc, query, doc_id_hash, &self.positions_readers, &idf_map, avg_field_len, k1, b).map(|s| (s, doc))
})
.collect();

Expand Down Expand Up @@ -1717,11 +1743,88 @@ fn infer_field_type(field: &str, value: &serde_json::Value) -> Result<Option<Fie
})
}

/// BM25 inverse document frequency: `ln((N - df + 0.5) / (df + 0.5))`, clamped to be
/// non-negative (a term present in more than roughly half the corpus would otherwise
/// score negative). A document frequency of zero yields `0.0` so that terms absent
/// from the inverted index contribute nothing to the score. For `df > N` the ratio is
/// negative, `ln` produces NaN, and `max(0.0)` still returns `0.0`.
fn bm25_idf(df: usize, n_docs: usize) -> f32 {
    match df {
        0 => 0.0,
        _ => {
            let ratio = (n_docs as f32 - df as f32 + 0.5) / (df as f32 + 0.5);
            ratio.ln().max(0.0)
        }
    }
}

/// Extract the target field name from a `SearchQuery` for BM25 field-length
/// normalization. Match, phrase, and term queries name a field directly; every
/// other variant (MatchAll, Bool, Range, ...) falls back to `"content"`.
fn extract_target_field(query: &SearchQuery) -> String {
    let field = match query {
        SearchQuery::Match(q) => q.field.as_str(),
        SearchQuery::Phrase(q) => q.field.as_str(),
        SearchQuery::Term(q) => q.field.as_str(),
        _ => "content",
    };
    field.to_string()
}

/// Collect all query terms from a `SearchQuery` (for match/phrase/term queries).
///
/// An empty `target_field` acts as a wildcard and accepts terms from any field;
/// otherwise only match/phrase clauses on exactly that field contribute terms.
/// (The IDF-map builder calls this with `""`, so without the wildcard handling
/// match and phrase queries would never yield any terms and the IDF map would
/// stay empty.) Term queries contribute regardless of field, matching the
/// original behavior; fuzzy term queries are skipped.
fn extract_query_terms(query: &SearchQuery, target_field: &str) -> Vec<String> {
    // Empty target field == "any field".
    let field_matches = |field: &str| target_field.is_empty() || field == target_field;
    match query {
        SearchQuery::Match(mq) if field_matches(&mq.field) => tokenize(&mq.value),
        SearchQuery::Phrase(pq) if field_matches(&pq.field) => tokenize(&pq.value),
        SearchQuery::Term(tq) if tq.fuzziness.is_none() => {
            // Exact term queries: lowercase to line up with the tokenizer's
            // normalization; non-string values carry no terms.
            if let serde_json::Value::String(s) = &tq.value {
                vec![s.to_lowercase()]
            } else {
                vec![]
            }
        }
        SearchQuery::Bool(bq) => {
            // Recurse into scoring clauses only (must/should); filter and
            // must_not clauses do not contribute scoring terms.
            let mut terms = Vec::new();
            for q in bq.must.iter().chain(bq.should.iter()) {
                terms.extend(extract_query_terms(q, target_field));
            }
            terms
        }
        _ => vec![],
    }
}

/// Compute the average token length of `field` across all documents in the index.
/// Only documents where the field exists and holds a string value are counted;
/// if no document qualifies, returns 1.0 to keep downstream divisions safe.
fn compute_avg_field_length(
    documents: &std::collections::BTreeMap<String, IndexDocument>,
    field: &str,
) -> f32 {
    let lengths: Vec<usize> = documents
        .values()
        .filter_map(|doc| doc.source.get(field))
        .filter_map(|val| val.as_str())
        .map(|s| tokenize(s).len())
        .collect();
    if lengths.is_empty() {
        1.0
    } else {
        lengths.iter().sum::<usize>() as f32 / lengths.len() as f32
    }
}

#[allow(clippy::too_many_arguments)]
fn score_query(
document: &IndexDocument,
query: &SearchQuery,
doc_id: u64,
positions_readers: &[cloudsearch_storage::inverted_index::PositionsReader],
idf_map: &std::collections::BTreeMap<String, f32>,
avg_field_len: f32,
k1: f32,
b: f32,
) -> Option<f32> {
match query {
SearchQuery::MatchAll => Some(1.0),
Expand All @@ -1732,31 +1835,79 @@ fn score_query(
SearchQuery::Terms(terms) => matches_terms_query(document, terms).then_some(1.0),
SearchQuery::Range(range) => matches_range_query(document, range).then_some(1.0),
SearchQuery::Bool(bool_query) => {
score_bool_query(document, bool_query, doc_id, positions_readers)
score_bool_query(document, bool_query, doc_id, positions_readers, idf_map, avg_field_len, k1, b)
}
SearchQuery::Prefix(prefix) => matches_prefix_query(document, prefix).then_some(1.0),
SearchQuery::Wildcard(wc) => matches_wildcard_query(document, wc).then_some(1.0),
SearchQuery::Match(mq) => score_match_query(document, mq),
SearchQuery::Match(mq) => {
score_match_query(document, mq, doc_id, positions_readers, idf_map, avg_field_len, k1, b)
}
SearchQuery::Phrase(phrase) => {
score_phrase_query(document, phrase, doc_id, positions_readers)
}
}
}

#[allow(clippy::cast_precision_loss)]
fn score_match_query(document: &IndexDocument, query: &MatchQuery) -> Option<f32> {
#[allow(clippy::cast_precision_loss, clippy::too_many_arguments)]
fn score_match_query(
document: &IndexDocument,
query: &MatchQuery,
doc_id: u64,
positions_readers: &[cloudsearch_storage::inverted_index::PositionsReader],
idf_map: &std::collections::BTreeMap<String, f32>,
avg_field_len: f32,
k1: f32,
b: f32,
) -> Option<f32> {
let field_str = document.source.get(&query.field)?.as_str()?;
let field_tokens = tokenize(field_str);
let query_tokens = tokenize(&query.value);
if query_tokens.is_empty() {
return None;
}
let field_set: std::collections::HashSet<&String> = field_tokens.iter().collect();
let matched = query_tokens
.iter()
.filter(|t| field_set.contains(t))
.count();
(matched > 0).then(|| matched as f32 / query_tokens.len() as f32)
let doc_len = field_tokens.len();

// For each query token, look up DF from positions readers and TF from the posting
// for this specific document, then compute BM25 and sum
let mut total_score = 0.0f32;
let mut matched = 0;

for token in &query_tokens {
// Find TF for this document across all segment readers
let mut tf = 0u32;
for reader in positions_readers {
if let Some(pl) = reader.get(token) {
// Binary search for this doc_id
if let Ok(idx) = pl.docs.binary_search_by(|p| p.doc_id.cmp(&doc_id)) {
tf += pl.docs[idx].term_freq;
}
}
}

if tf == 0 {
// Token not in inverted index for this doc — check source field directly
// (doc may not have been flushed yet, only in WAL)
if field_tokens.contains(token) {
// Count occurrences in field_tokens
tf = field_tokens.iter().filter(|t| *t == token).count() as u32;
} else {
continue;
}
}

// Use precomputed IDF if available, otherwise compute from df=matched docs
let idf = idf_map.get(token).copied().unwrap_or(1.0);
let term_score = idf * (tf as f32 * (k1 + 1.0))
/ (tf as f32 + k1 * (1.0 - b + b * doc_len as f32 / avg_field_len.max(1.0)));
total_score += term_score;
matched += 1;
}

if matched == 0 {
None
} else {
Some(total_score / matched as f32)
}
}

/// Score a phrase query by checking if query terms appear consecutively in document text.
Expand Down Expand Up @@ -2089,33 +2240,37 @@ fn build_wildcard_regex(pattern: &str) -> Option<Regex> {
Regex::new(&format!("^{regex_pattern}$")).ok()
}

#[allow(clippy::cast_precision_loss)]
#[allow(clippy::cast_precision_loss, clippy::too_many_arguments)]
fn score_bool_query(
document: &IndexDocument,
bool_query: &BoolQuery,
doc_id: u64,
positions_readers: &[cloudsearch_storage::inverted_index::PositionsReader],
idf_map: &std::collections::BTreeMap<String, f32>,
avg_field_len: f32,
k1: f32,
b: f32,
) -> Option<f32> {
// Evaluate each clause group once and store the scores.
let must_scores: Vec<Option<f32>> = bool_query
.must
.iter()
.map(|q| score_query(document, q, doc_id, positions_readers))
.map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b))
.collect();
let filter_scores: Vec<Option<f32>> = bool_query
.filter
.iter()
.map(|q| score_query(document, q, doc_id, positions_readers))
.map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b))
.collect();
let must_not_scores: Vec<Option<f32>> = bool_query
.must_not
.iter()
.map(|q| score_query(document, q, doc_id, positions_readers))
.map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b))
.collect();
let should_scores: Vec<Option<f32>> = bool_query
.should
.iter()
.map(|q| score_query(document, q, doc_id, positions_readers))
.map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b))
.collect();

// All must clauses must match.
Expand Down Expand Up @@ -5439,9 +5594,11 @@ mod tests {
..Default::default()
});
assert_eq!(hello.hits.total, 2);
// query "hello" has 1 token; both docs match 1/1 = 1.0, so the tie-breaker (alphabetical id) applies
// With BM25 scoring, "hello" has df=2 (doc-1 and doc-2), n=3.
// IDF = max(0, ln((3-2+0.5)/(2+0.5))) ≈ 0, so score ≈ 0 unless TF is high.
// Both docs match "hello" with equal TF, so tie-breaker is alphabetical id.
assert_eq!(hello.hits.hits[0].id, "doc-1");
assert_eq!(hello.hits.hits[0].score, Some(1.0));
assert!(hello.hits.hits[0].score.unwrap_or(0.0) >= 0.0);

// Match "hello world" - both docs match 2/2 tokens = 1.0 (tie goes to lower doc id)
let both = handle.search(&SearchRequest {
Expand All @@ -5452,11 +5609,14 @@ mod tests {
..Default::default()
});
assert_eq!(both.hits.total, 2);
// Both match 2/2 = 1.0, tie-breaker is alphabetical: doc-1 < doc-2
// Both docs match 2 tokens. With BM25, doc-1 ("hello world" → 2 tokens, len=2)
// and doc-2 ("hello there world" → 3 tokens, len=3) get different scores even with
// identical term frequencies, due to field length normalization.
// Tie-breaker is alphabetical: doc-1 < doc-2.
assert_eq!(both.hits.hits[0].id, "doc-1");
assert_eq!(both.hits.hits[0].score, Some(1.0));
assert!(both.hits.hits[0].score.unwrap_or(0.0) >= 0.0);
assert_eq!(both.hits.hits[1].id, "doc-2");
assert_eq!(both.hits.hits[1].score, Some(1.0));
assert!(both.hits.hits[1].score.unwrap_or(0.0) >= 0.0);

// Match "xyz" finds nothing
let no_match = handle.search(&SearchRequest {
Expand Down
Loading