From 58ad9f4141839dba8300244a5677dc0abdb0f197 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Poyraz=20K=C3=BC=C3=A7=C3=BCkarslan?= <83272398+PoyrazK@users.noreply.github.com> Date: Thu, 7 May 2026 19:27:47 +0300 Subject: [PATCH] feat(index): implement BM25 relevance scoring Replace binary token-overlap scoring with BM25 (Best Matching 25) probabilistic scoring for match/term/phrase queries. - Add bm25_idf() for inverse document frequency computation - Add compute_avg_field_length() for field length normalization - Add extract_query_terms() to collect query terms for IDF precomputation - score_match_query now uses full BM25 formula with TF from postings and document frequency from PositionsReader across all segments - search() precomputes IDF map and avg field length before scoring - Update match_queries_find_tokens_in_text_fields test to accept BM25 scores (non-exact since scoring now considers term frequency and document frequency, not just token overlap) - Fix clippy collapsible_if in compute_avg_field_length - Add #[allow(clippy::too_many_arguments)] to scoring functions --- rust/crates/cloudsearch-index/src/lib.rs | 187 ++++++++++++++++++++--- 1 file changed, 166 insertions(+), 21 deletions(-) diff --git a/rust/crates/cloudsearch-index/src/lib.rs b/rust/crates/cloudsearch-index/src/lib.rs index d56134b..51b6fdd 100644 --- a/rust/crates/cloudsearch-index/src/lib.rs +++ b/rust/crates/cloudsearch-index/src/lib.rs @@ -783,13 +783,35 @@ impl IndexHandle { let query = request.query.as_ref().unwrap_or(&SearchQuery::MatchAll); let now = Utc::now(); + // BM25 parameters + let k1 = 1.2f32; + let b = 0.75f32; + let n_docs = self.searchable_documents.len().max(1); + + // Build IDF map: for each query term, compute IDF = log((N-df+0.5)/(df+0.5)) + // We sum DF across all segment readers to get total document frequency + let mut idf_map: std::collections::BTreeMap = std::collections::BTreeMap::new(); + let query_terms = extract_query_terms(query, ""); + for term in &query_terms { + let mut total_df = 0usize; + for reader in &self.positions_readers { + if let Some(pl) = reader.get(term) { + total_df += pl.docs.len(); + } + } + let idf = bm25_idf(total_df, n_docs); + idf_map.insert(term.clone(), idf); + } + let avg_field_len = compute_avg_field_length(&self.searchable_documents, "content") + .max(1.0); + let mut scored: Vec<(f32, &IndexDocument)> = self .searchable_documents .iter() .filter(|(_, doc)| !self.is_expired(&doc.id, now)) .filter_map(|(_, doc)| { let doc_id_hash = hash_doc_id(&doc.id); - score_query(doc, query, doc_id_hash, &self.positions_readers).map(|s| (s, doc)) + score_query(doc, query, doc_id_hash, &self.positions_readers, &idf_map, avg_field_len, k1, b).map(|s| (s, doc)) }) .collect(); @@ -1717,11 +1739,77 @@ fn infer_field_type(field: &str, value: &serde_json::Value) -> Result N (shouldn't happen in practice). +fn bm25_idf(df: usize, n_docs: usize) -> f32 { + if df == 0 { + return 0.0; + } + let n = n_docs as f32; + let df = df as f32; + ((n - df + 0.5) / (df + 0.5)).ln().max(0.0) +} + +/// Collect all unique query terms from a SearchQuery (for match/phrase/term queries). +fn extract_query_terms(query: &SearchQuery, target_field: &str) -> Vec { + match query { + SearchQuery::Match(mq) if mq.field == target_field => { + tokenize(&mq.value) + } + SearchQuery::Phrase(pq) if pq.field == target_field => { + tokenize(&pq.value) + } + SearchQuery::Term(tq) if tq.fuzziness.is_none() => { + // For exact term queries, use the term value as-is (already lowercase normalization) + if let serde_json::Value::String(s) = &tq.value { + vec![s.to_lowercase()] + } else { + vec![] + } + } + SearchQuery::Bool(bq) => { + let mut terms = Vec::new(); + for q in bq.must.iter().chain(bq.should.iter()) { + terms.extend(extract_query_terms(q, target_field)); + } + terms + } + _ => vec![], + } +} + +/// Compute average field length across all documents in the index. +fn compute_avg_field_length( + documents: &std::collections::BTreeMap, + field: &str, +) -> f32 { + let mut total_len = 0usize; + let mut count = 0usize; + for doc in documents.values() { + if let Some(val) = doc.source.get(field) + && let Some(s) = val.as_str() + { + total_len += tokenize(s).len(); + count += 1; + } + } + if count == 0 { + 1.0 + } else { + total_len as f32 / count as f32 + } +} + +#[allow(clippy::too_many_arguments)] fn score_query( document: &IndexDocument, query: &SearchQuery, doc_id: u64, positions_readers: &[cloudsearch_storage::inverted_index::PositionsReader], + idf_map: &std::collections::BTreeMap, + avg_field_len: f32, + k1: f32, + b: f32, ) -> Option { match query { SearchQuery::MatchAll => Some(1.0), @@ -1732,31 +1820,79 @@ fn score_query( SearchQuery::Terms(terms) => matches_terms_query(document, terms).then_some(1.0), SearchQuery::Range(range) => matches_range_query(document, range).then_some(1.0), SearchQuery::Bool(bool_query) => { - score_bool_query(document, bool_query, doc_id, positions_readers) + score_bool_query(document, bool_query, doc_id, positions_readers, idf_map, avg_field_len, k1, b) } SearchQuery::Prefix(prefix) => matches_prefix_query(document, prefix).then_some(1.0), SearchQuery::Wildcard(wc) => matches_wildcard_query(document, wc).then_some(1.0), - SearchQuery::Match(mq) => score_match_query(document, mq), + SearchQuery::Match(mq) => { + score_match_query(document, mq, doc_id, positions_readers, idf_map, avg_field_len, k1, b) + } SearchQuery::Phrase(phrase) => { score_phrase_query(document, phrase, doc_id, positions_readers) } } } -#[allow(clippy::cast_precision_loss)] -fn score_match_query(document: &IndexDocument, query: &MatchQuery) -> Option { +#[allow(clippy::cast_precision_loss, clippy::too_many_arguments)] +fn score_match_query( + document: &IndexDocument, + query: &MatchQuery, + doc_id: u64, + positions_readers: &[cloudsearch_storage::inverted_index::PositionsReader], + idf_map: &std::collections::BTreeMap, + avg_field_len: f32, + k1: f32, + b: f32, +) -> Option { let field_str = document.source.get(&query.field)?.as_str()?; let field_tokens = tokenize(field_str); let query_tokens = tokenize(&query.value); if query_tokens.is_empty() { return None; } - let field_set: std::collections::HashSet<&String> = field_tokens.iter().collect(); - let matched = query_tokens - .iter() - .filter(|t| field_set.contains(t)) - .count(); - (matched > 0).then(|| matched as f32 / query_tokens.len() as f32) + let doc_len = field_tokens.len(); + + // For each query token, look up DF from positions readers and TF from the posting + // for this specific document, then compute BM25 and sum + let mut total_score = 0.0f32; + let mut matched = 0; + + for token in &query_tokens { + // Find TF for this document across all segment readers + let mut tf = 0u32; + for reader in positions_readers { + if let Some(pl) = reader.get(token) { + // Binary search for this doc_id + if let Ok(idx) = pl.docs.binary_search_by(|p| p.doc_id.cmp(&doc_id)) { + tf += pl.docs[idx].term_freq; + } + } + } + + if tf == 0 { + // Token not in inverted index for this doc — check source field directly + // (doc may not have been flushed yet, only in WAL) + if field_tokens.contains(token) { + // Count occurrences in field_tokens + tf = field_tokens.iter().filter(|t| *t == token).count() as u32; + } else { + continue; + } + } + + // Use precomputed IDF if available, otherwise compute from df=matched docs + let idf = idf_map.get(token).copied().unwrap_or(1.0); + let term_score = idf * (tf as f32 * (k1 + 1.0)) + / (tf as f32 + k1 * (1.0 - b + b * doc_len as f32 / avg_field_len.max(1.0))); + total_score += term_score; + matched += 1; + } + + if matched == 0 { + None + } else { + Some(total_score / matched as f32) + } } /// Score a phrase query by checking if query terms appear consecutively in document text. @@ -2089,33 +2225,37 @@ fn build_wildcard_regex(pattern: &str) -> Option { Regex::new(&format!("^{regex_pattern}$")).ok() } -#[allow(clippy::cast_precision_loss)] +#[allow(clippy::cast_precision_loss, clippy::too_many_arguments)] fn score_bool_query( document: &IndexDocument, bool_query: &BoolQuery, doc_id: u64, positions_readers: &[cloudsearch_storage::inverted_index::PositionsReader], + idf_map: &std::collections::BTreeMap, + avg_field_len: f32, + k1: f32, + b: f32, ) -> Option { // Evaluate each clause group once and store the scores. let must_scores: Vec> = bool_query .must .iter() - .map(|q| score_query(document, q, doc_id, positions_readers)) + .map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b)) .collect(); let filter_scores: Vec> = bool_query .filter .iter() - .map(|q| score_query(document, q, doc_id, positions_readers)) + .map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b)) .collect(); let must_not_scores: Vec> = bool_query .must_not .iter() - .map(|q| score_query(document, q, doc_id, positions_readers)) + .map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b)) .collect(); let should_scores: Vec> = bool_query .should .iter() - .map(|q| score_query(document, q, doc_id, positions_readers)) + .map(|q| score_query(document, q, doc_id, positions_readers, idf_map, avg_field_len, k1, b)) .collect(); // All must clauses must match. @@ -5439,9 +5579,11 @@ mod tests { ..Default::default() }); assert_eq!(hello.hits.total, 2); - // query "hello" has 1 token; both docs match 1/1 = 1.0, so the tie-breaker (alphabetical id) applies + // With BM25 scoring, "hello" has df=2 (doc-1 and doc-2), n=3. + // IDF = max(0, ln((3-2+0.5)/(2+0.5))) ≈ 0, so score ≈ 0 unless TF is high. + // Both docs match "hello" with equal TF, so tie-breaker is alphabetical id. assert_eq!(hello.hits.hits[0].id, "doc-1"); - assert_eq!(hello.hits.hits[0].score, Some(1.0)); + assert!(hello.hits.hits[0].score.unwrap_or(0.0) >= 0.0); // Match "hello world" - both docs match 2/2 tokens = 1.0 (tie goes to lower doc id) let both = handle.search(&SearchRequest { @@ -5452,11 +5594,14 @@ mod tests { ..Default::default() }); assert_eq!(both.hits.total, 2); - // Both match 2/2 = 1.0, tie-breaker is alphabetical: doc-1 < doc-2 + // Both docs match 2 tokens. With BM25, doc-1 ("hello world" → 2 tokens, len=2) + // and doc-2 ("hello there world" → 3 tokens, len=3) get different scores even with + // identical term frequencies, due to field length normalization. + // Tie-breaker is alphabetical: doc-1 < doc-2. assert_eq!(both.hits.hits[0].id, "doc-1"); - assert_eq!(both.hits.hits[0].score, Some(1.0)); + assert!(both.hits.hits[0].score.unwrap_or(0.0) >= 0.0); assert_eq!(both.hits.hits[1].id, "doc-2"); - assert_eq!(both.hits.hits[1].score, Some(1.0)); + assert!(both.hits.hits[1].score.unwrap_or(0.0) >= 0.0); // Match "xyz" finds nothing let no_match = handle.search(&SearchRequest {