Skip to content

Commit df0fb34

Browse files
SQL-3008: Support $rankFusion Operator in Schema Derivation (#110)
* SQL-3008 - signed commits
* Nest score details in if-statement
* Rename RankFusionPipeline to RankFusionInput to match docs
* Address schema_derivation.rs comments
* Replace pipeline functions with macros. Format code
* Fix failing test
* Make scoreDetails optional. Add test for deduplicating match + sort pipelines
* Additional stage tests
* Use if-let to fix clippy warnings
* Add a starting schema to type narrowing test for validation
* Remove extraneous test comments
* Simplify if let into a single conditional
* Use merge() instead of union() for pipeline and merge schemas
* Address clippy errors. Convert nested match to if-let
* Replace fold() with for-loop. Add error handling test case.
* Refactor for-loop into try_fold
* Clean up tests. Move macros into rank_fusion module
* Format JSON in tests. destructure kv-pair in try_fold
* Update agg-ast/schema_derivation/src/schema_derivation.rs

  Co-authored-by: Jonathan Chemburkar <jonathan.chemburkar@mongodb.com>
* Remove macro-export. Format stage test json

---------

Co-authored-by: Jonathan Chemburkar <jonathan.chemburkar@mongodb.com>
1 parent 91160d8 commit df0fb34

File tree

5 files changed

+581
-2
lines changed

5 files changed

+581
-2
lines changed

agg-ast/ast/src/definitions.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ pub enum Stage {
7878
Sample(Sample),
7979
#[serde(rename = "$unionWith")]
8080
UnionWith(UnionWith),
81+
#[serde(rename = "$rankFusion")]
82+
RankFusion(RankFusion),
8183

8284
// Search stages
8385
#[serde(rename = "$graphLookup")]
@@ -620,6 +622,23 @@ pub struct Bucket {
620622
pub default: Option<Bson>,
621623
pub output: Option<LinkedHashMap<String, Expression>>,
622624
}
625+
/// Body of the `$rankFusion` aggregation stage (carried by the
/// `Stage::RankFusion` variant).
///
/// Field names are (de)serialized in camelCase, so `score_details` is read
/// from and written to the `scoreDetails` key.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
626+
#[serde(rename_all = "camelCase")]
627+
pub struct RankFusion {
628+
/// The named input pipelines whose results are being fused.
pub input: RankFusionInput,
629+
/// Optional per-pipeline weights; `None` when the key is absent.
pub combination: Option<RankFusionCombination>,
630+
/// Optional `scoreDetails` flag; `None` when the key is absent.
pub score_details: Option<bool>,
631+
}
632+
633+
/// The `input` document of a `$rankFusion` stage.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
634+
pub struct RankFusionInput {
635+
/// Sub-pipelines keyed by their user-chosen name; each value is the list
/// of stages making up that pipeline.
pub pipelines: LinkedHashMap<String, Vec<Stage>>,
636+
}
637+
638+
/// The `combination` document of a `$rankFusion` stage.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
639+
pub struct RankFusionCombination {
640+
/// Weight assigned to each input pipeline, keyed by pipeline name.
/// Weights are parsed as `f64`, so fractional values such as `0.5` are
/// accepted.
pub weights: LinkedHashMap<String, f64>,
641+
}
623642

624643
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
625644
#[serde(rename_all = "camelCase")]
@@ -1783,6 +1802,7 @@ impl Stage {
17831802
Stage::GeoNear(_) => "$geoNear",
17841803
Stage::Sample(_) => "$sample",
17851804
Stage::UnionWith(_) => "$unionWith",
1805+
Stage::RankFusion(_) => "$rankFusion",
17861806
Stage::GraphLookup(_) => "$graphLookup",
17871807
Stage::AtlasSearchStage(_) => "<Atlas search stage>",
17881808
}

agg-ast/ast/src/serde_test.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,6 +1605,167 @@ mod stage_test {
16051605
}}"#
16061606
);
16071607
}
1608+
// Serde round-trip tests for the `$rankFusion` stage. The two macros below
// build the stage lists shared by the expected values of several tests.
mod rank_fusion {
1609+
1610+
use crate::definitions::Stage::{AtlasSearchStage, Match, Sort};
1611+
use crate::definitions::{
1612+
AtlasSearchStage::{Search, VectorSearch},
1613+
Expression,
1614+
Expression::Literal,
1615+
LiteralValue, MatchBinaryOp, MatchExpression, MatchField, MatchStage, RankFusion,
1616+
RankFusionCombination, RankFusionInput, Ref, Stage,
1617+
Stage::Limit,
1618+
};
1619+
use crate::map;
1620+
1621+
// Expands to a one-stage pipeline: a `$vectorSearch` stage with the
// `index`, `path`, `queryVector` and `numCandidates` arguments used by the
// test inputs below.
macro_rules! vector_pipeline {
1622+
() => {
1623+
vec![Stage::AtlasSearchStage(VectorSearch(Box::new(
1624+
Expression::Document(map! {
1625+
"index".to_string() => Literal(LiteralValue::String("hybrid-vector-search".to_string())),
1626+
"path".to_string() => Literal(LiteralValue::String("plot_embedding_voyage_3_large".to_string())),
1627+
"queryVector".to_string() => Expression::Array(vec![Literal(LiteralValue::Double(10.6)), Expression::Literal(LiteralValue::Double(60.5))]),
1628+
"numCandidates".to_string() => Literal(LiteralValue::Int32(100)),
1629+
}),
1630+
)))]
1631+
};
1632+
}
1633+
1634+
// Expands to a two-stage pipeline: a `$search` stage with a `phrase` query
// followed by `$limit: 20`.
macro_rules! text_search_pipeline {
1635+
() => {
1636+
vec![Stage::AtlasSearchStage(
1637+
Search(Box::new(Expression::Document(
1638+
map! {
1639+
"index".to_string() => Literal(LiteralValue::String("hybrid-full-text-search".to_string())),
1640+
"phrase".to_string() => Expression::Document(map! {
1641+
"query".to_string() => Literal(LiteralValue::String("star wars".to_string())),
1642+
"path".to_string() => Literal(LiteralValue::String("title".to_string())),
1643+
})
1644+
},
1645+
)))
1646+
), Limit(20)]
1647+
};
1648+
}
1649+
1650+
// Only `input` is supplied: `combination` and `scoreDetails` are absent from
// the input and are expected to come back as `None`.
test_serde_stage!(
1651+
rank_fusion_single_pipeline,
1652+
expected = Stage::RankFusion(RankFusion {
1653+
input: RankFusionInput {
1654+
pipelines: map! {
1655+
"searchOne".to_string() => vector_pipeline!()
1656+
},
1657+
},
1658+
combination: None,
1659+
score_details: None
1660+
}),
1661+
input = r#"stage: {"$rankFusion": {
1662+
"input": { "pipelines": { searchOne: [{ "$vectorSearch" : {"index" : "hybrid-vector-search", "path" : "plot_embedding_voyage_3_large", "queryVector": [10.6, 60.5], "numCandidates": 100} }] } },
1663+
}}"#
1664+
);
1665+
1666+
// Two named pipelines plus explicit `combination.weights` and
// `scoreDetails: true`; every optional field is populated.
test_serde_stage!(
1667+
rank_fusion_multiple_pipelines_with_weights,
1668+
expected = Stage::RankFusion(RankFusion {
1669+
input: RankFusionInput {
1670+
pipelines: map! {
1671+
"vectorPipeline".to_string() => vector_pipeline!(),
1672+
"fullTextPipeline".to_string() => text_search_pipeline!(),
1673+
},
1674+
},
1675+
combination: Some(RankFusionCombination {
1676+
weights: map! {
1677+
"vectorPipeline".to_string() => 0.5,
1678+
"fullTextPipeline".to_string() => 0.5
1679+
}
1680+
}),
1681+
score_details: Some(true)
1682+
}),
1683+
input = r#"stage: {
1684+
"$rankFusion": {
1685+
"input": {
1686+
pipelines: {
1687+
vectorPipeline: [
1688+
{
1689+
"$vectorSearch": {
1690+
"index": "hybrid-vector-search",
1691+
"path": "plot_embedding_voyage_3_large",
1692+
"queryVector": [10.6, 60.5],
1693+
"numCandidates": 100,
1694+
}
1695+
}
1696+
],
1697+
fullTextPipeline: [
1698+
{
1699+
"$search": {
1700+
"index": "hybrid-full-text-search",
1701+
"phrase": {
1702+
"query": "star wars",
1703+
"path": "title"
1704+
}
1705+
}
1706+
},
1707+
{ "$limit": 20 }
1708+
]
1709+
}
1710+
},
1711+
"combination": {
1712+
weights: {
1713+
vectorPipeline: 0.5,
1714+
fullTextPipeline: 0.5
1715+
}
1716+
},
1717+
"scoreDetails": true
1718+
}
1719+
}"#
1720+
);
1721+
1722+
// The input repeats the `searchOne` key; the expected value holds a single
// `searchOne` entry whose stages are those of the second occurrence
// (`metacritic > 75`, sort by `title`).
test_serde_stage!(
1723+
pipelines_are_deduplicated,
1724+
expected = Stage::RankFusion(RankFusion {
1725+
input: RankFusionInput {
1726+
pipelines: map! {
1727+
"searchOne".to_string() => vec![AtlasSearchStage(Search(Box::new(
1728+
Expression::Document(map! {
1729+
"index".to_string() => Literal(LiteralValue::String("hybrid-full-text-search".to_string())),
1730+
"phrase".to_string() => Expression::Document(map! {
1731+
"query".to_string() => Literal(LiteralValue::String("adventure".to_string())),
1732+
"path".to_string() => Literal(LiteralValue::String("plot".to_string()))
1733+
}),
1734+
1735+
}),
1736+
))),
1737+
Match(MatchStage {
1738+
expr: vec![MatchExpression::Field(MatchField {
1739+
field: Ref::FieldRef("metacritic".to_string()),
1740+
ops: map! { MatchBinaryOp::Gt => bson::Bson::Int32(75) }
1741+
})]
1742+
}) ,
1743+
Sort(map! {"title".to_string() => 1})]
1744+
},
1745+
},
1746+
combination: None,
1747+
score_details: Some(false)
1748+
}),
1749+
input = r#"stage: { "$rankFusion" : {
1750+
"input" : {
1751+
"pipelines" : {
1752+
searchOne: [
1753+
{ "$search": { "index": "hybrid-full-text-search", "phrase": { "query": "adventure", "path": "plot"}}},
1754+
{ "$match": { "genres": "Western", "year": { "$lt": 1980 }}},
1755+
{ "$sort": { "runtime": 1}
1756+
}],
1757+
searchOne: [
1758+
{ "$search": { "index": "hybrid-full-text-search", "phrase": { "query": "adventure","path": "plot"}}},
1759+
{ "$match": { "metacritic": { "$gt": 75 }}},
1760+
{ "$sort": { "title": 1}
1761+
}]
1762+
}
1763+
},
1764+
"scoreDetails": false
1765+
}
1766+
}"#
1767+
);
1768+
}
16081769

16091770
mod densify {
16101771
use crate::definitions::{Densify, DensifyRange, DensifyRangeBounds, Stage};

agg-ast/schema_derivation/src/schema_derivation.rs

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ use crate::{
77
use agg_ast::definitions::{
88
AtlasSearchStage, Bucket, BucketAuto, ConciseSubqueryLookup, Densify, Documents,
99
EqualityLookup, Expression, Fill, FillOutput, GraphLookup, Group, LiteralValue, Lookup,
10-
LookupFrom, Namespace, ProjectItem, ProjectStage, Ref, SetWindowFields, Stage, SubqueryLookup,
11-
TaggedOperator, UnionWith, Unset, UntaggedOperator, UntaggedOperatorName, Unwind,
10+
LookupFrom, Namespace, ProjectItem, ProjectStage, RankFusion, Ref, SetWindowFields, Stage,
11+
SubqueryLookup, TaggedOperator, UnionWith, Unset, UntaggedOperator, UntaggedOperatorName,
12+
Unwind,
1213
};
1314
use linked_hash_map::LinkedHashMap;
1415
use mongosql::{
@@ -433,6 +434,61 @@ impl DeriveSchema for Stage {
433434
}
434435
}
435436

437+
fn rank_fusion_derive_schema(
438+
rank_fusion: &RankFusion,
439+
state: &mut ResultSetState,
440+
) -> Result<Schema> {
441+
// Derive the schema for each pipeline and union them together
442+
let mut unioned_schema_pipelines: Schema = rank_fusion
443+
.input
444+
.pipelines
445+
.iter()
446+
.try_fold(Schema::Unsat, |acc, (_, pipeline)| {
447+
let derived_pipeline_schema =
448+
derive_schema_for_pipeline(pipeline.clone(), None, &mut state.clone())?;
449+
450+
Ok(acc.union(&derived_pipeline_schema))
451+
})?;
452+
453+
// 2. If score_details is true, add scoreDetails schema to the overall schema
454+
if let Some(true) = rank_fusion.score_details {
455+
let score_details_document: Document = Document {
456+
keys: map! {
457+
"scoreDetails".to_string() => Schema::Document(Document {
458+
keys: map! {
459+
"value".to_string() => Schema::Atomic(Atomic::Decimal),
460+
"description".to_string() => Schema::Atomic(Atomic::String),
461+
"details".to_string() => Schema::Array(Box::new(Schema::Document(Document {
462+
keys: map! {
463+
"inputPipelineName".to_string() => Schema::Atomic(Atomic::String),
464+
"rank".to_string() => Schema::Atomic(Atomic::Integer),
465+
"weight".to_string() => Schema::Atomic(Atomic::Integer),
466+
"value".to_string() => Schema::Atomic(Atomic::Decimal),
467+
"details".to_string() => Schema::Array(Box::new(Schema::Any)),
468+
},
469+
required: set!("inputPipelineName".to_string(), "rank".to_string(),),
470+
..Default::default()
471+
})))
472+
},
473+
required: set!("value".to_string(), "description".to_string(),),
474+
..Default::default()
475+
})
476+
},
477+
required: set!("scoreDetails".to_string()),
478+
additional_properties: false,
479+
jaccard_index: None,
480+
};
481+
482+
// Merge the pipeline schema and score details schema together
483+
if let Schema::Document(ref pipeline_doc) = unioned_schema_pipelines {
484+
unioned_schema_pipelines =
485+
Schema::Document(pipeline_doc.clone().merge(score_details_document));
486+
}
487+
}
488+
489+
Ok(unioned_schema_pipelines)
490+
}
491+
436492
/// bucket_derive_schema derives the schema for a $bucket stage. The schema is defined by the output field,
437493
/// and contains _id as well.
438494
fn bucket_derive_schema(bucket: &Bucket, state: &mut ResultSetState) -> Result<Schema> {
@@ -1128,6 +1184,7 @@ impl DeriveSchema for Stage {
11281184
Stage::Lookup(l) => lookup_derive_schema(l, state),
11291185
Stage::Match(ref m) => m.derive_schema(state),
11301186
Stage::Project(p) => project_derive_schema(p, state),
1187+
Stage::RankFusion(rf) => rank_fusion_derive_schema(rf, state),
11311188
Stage::Redact(_) => Ok(state.result_set_schema.to_owned()),
11321189
Stage::ReplaceWith(r) => r
11331190
.to_owned()

0 commit comments

Comments
 (0)