Skip to content

Commit b788c41

Browse files
committed
add KNearestNeighbor
1 parent fc668e9 commit b788c41

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

Sources/ElasticsearchQueryBuilder/Components.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ public struct RootComponent<Component: DictComponent>: RootQueryable, DictCompon
3030
}
3131
}
3232

33+
public struct EmptyArrayComponent: ArrayComponent {
34+
public func makeArray() -> [QueryDict] { [] }
35+
}
36+
3337
/// Namespace for `@ElasticsearchQueryBuilder` components
3438
public enum esb {}
3539

@@ -184,6 +188,45 @@ extension esb {
184188
}
185189
}
186190

191+
/// Adds `knn` block to the query syntax.
192+
public struct KNearestNeighbor<Component: ArrayComponent>: DictComponent {
193+
let field: String
194+
let vector: [Double]
195+
let options: QueryDict
196+
var filter: Component
197+
public init(
198+
_ field: String,
199+
_ vector: [Double],
200+
options: () -> QueryDict = { [:] },
201+
@QueryArrayBuilder filter: () -> Component
202+
) {
203+
self.field = field
204+
self.vector = vector
205+
self.options = options()
206+
self.filter = filter()
207+
}
208+
public init(
209+
_ field: String,
210+
_ vector: [Double],
211+
options: () -> QueryDict = { [:] }
212+
) where Component == EmptyArrayComponent {
213+
self.field = field
214+
self.vector = vector
215+
self.options = options()
216+
self.filter = EmptyArrayComponent()
217+
}
218+
public func makeDict() -> QueryDict {
219+
var dict: QueryDict = self.options
220+
dict["field"] = .string(self.field)
221+
dict["query_vector"] = .array(self.vector)
222+
let filterValues = self.filter.makeCompactArray()
223+
if !filterValues.isEmpty {
224+
dict["filter"] = .array(filterValues.map(QueryValue.dict))
225+
}
226+
return [ "knn" : .dict(dict) ]
227+
}
228+
}
229+
187230
/// Adds `function_score` block to the query syntax.
188231
public struct FunctionScore<Component: DictComponent>: DictComponent {
189232
var component: Component

Tests/ElasticsearchQueryBuilderTests/ComponentTests.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,49 @@ final class BoolTests: XCTestCase {
175175
}
176176
}
177177

178+
final class KNearestNeighborTests: XCTestCase {
179+
func testBuildBasic() throws {
180+
@ElasticsearchQueryBuilder func build() -> some esb.QueryDSL {
181+
esb.KNearestNeighbor("vector_field", [1,2,3])
182+
}
183+
XCTAssertNoDifference(build().makeQuery(), [
184+
"knn": [
185+
"field": "vector_field",
186+
"query_vector": [1.0, 2.0, 3.0],
187+
]
188+
])
189+
}
190+
func testBuildWithOptionsAndFilter() throws {
191+
@ElasticsearchQueryBuilder func build() -> some esb.QueryDSL {
192+
esb.KNearestNeighbor("vector_field", [1,2,3]) {
193+
[
194+
"k": 5
195+
]
196+
} filter: {
197+
esb.Key("match_bool_prefix") {
198+
[
199+
"message": "quick brown f"
200+
]
201+
}
202+
}
203+
}
204+
XCTAssertNoDifference(build().makeQuery(), [
205+
"knn": [
206+
"field": "vector_field",
207+
"query_vector": [1.0, 2.0, 3.0],
208+
"k": 5,
209+
"filter": [
210+
[
211+
"match_bool_prefix": [
212+
"message": "quick brown f"
213+
]
214+
]
215+
]
216+
]
217+
])
218+
}
219+
}
220+
178221
final class FunctionScoreTests: XCTestCase {
179222
func testBuild() throws {
180223
@ElasticsearchQueryBuilder func build() -> some esb.QueryDSL {

0 commit comments

Comments
 (0)