Skip to content

Commit d9fceaa

Browse files
committed
feat: ✨ add support for finding the k nearest neighbors
1 parent 68addb2 commit d9fceaa

File tree

5 files changed

+314
-1
lines changed

5 files changed

+314
-1
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/rishitc/go-kd-tree
22

3-
go 1.21.4
3+
go 1.22.5
44

55
require (
66
github.com/google/flatbuffers v24.3.25+incompatible

kdtree.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package kdtree
22

33
import (
4+
"container/heap"
45
"fmt"
56
"sort"
67
"strings"
@@ -408,6 +409,93 @@ func nearestNeighbor[T Comparable[T]](d int, v, nn *T, cd int, r *kdNode[T]) *T
408409
return nn
409410
}
410411

412+
func (t *KDTree[T]) KNearestNeighbor(value T, k int) []T {
413+
if t == nil || t.root == nil || t.sz < k {
414+
return nil
415+
}
416+
417+
pqRes := NewBoundedPriorityQueue[T](k)
418+
kNearestNeighbor(k, t.dimensions, &value, &pqRes, 0, t.root)
419+
420+
res := make([]T, 0, pqRes.Len())
421+
for pqRes.Len() != 0 {
422+
d := heap.Pop(&pqRes).(Item[T]).Data
423+
res = append(res, d)
424+
}
425+
426+
return res
427+
}
428+
429+
// direction records which child link was followed while descending the tree.
type direction bool

const (
	// left means the descent continued into a node's left child.
	left direction = true
	// right means the descent continued into a node's right child. The type
	// is spelled out because a bare "right = false" inside this block would
	// declare an untyped boolean constant rather than a direction.
	right direction = false
)
435+
436+
type nodeInfo[T Comparable[T]] struct {
437+
node *kdNode[T]
438+
dir direction
439+
}
440+
441+
func kNearestNeighbor[T Comparable[T]](k, d int, v *T, pq *BoundedPriorityQueue[T], cd int, r *kdNode[T]) {
442+
if r == nil {
443+
return
444+
}
445+
446+
ncd := cd
447+
448+
var path []nodeInfo[T]
449+
for r != nil {
450+
info := nodeInfo[T]{
451+
node: r,
452+
}
453+
if rel := (*v).Order(r.value, ncd); rel == Lesser {
454+
r = r.left
455+
info.dir = left
456+
} else {
457+
r = r.right
458+
info.dir = right
459+
}
460+
path = append(path, info)
461+
462+
ncd = (ncd + 1) % d
463+
}
464+
465+
ncd = (ncd - 1 + d) % d // Go back to the dimension used for splitting at the leaf node.
466+
for path, cn, cDir := popLast(path); cn != nil; path, cn, cDir = popLast(path) {
467+
currentDistance := (*v).Dist(cn.value)
468+
heap.Push(pq, Item[T]{
469+
Data: cn.value,
470+
Priority: currentDistance,
471+
})
472+
473+
if pq.Len() < pq.Capacity() || (*v).DistDim(cn.value, ncd) < getFarthestDistance(pq) {
474+
var next *kdNode[T]
475+
if cDir == left {
476+
next = cn.right
477+
} else {
478+
next = cn.left
479+
}
480+
kNearestNeighbor(k, d, v, pq, (ncd+1)%d, next)
481+
}
482+
ncd = (ncd - 1 + d) % d
483+
}
484+
}
485+
486+
func getFarthestDistance[T Comparable[T]](pq *BoundedPriorityQueue[T]) int {
487+
v := pq.Peek().(Item[T])
488+
return v.Priority
489+
}
490+
491+
func popLast[T Comparable[T]](arr []nodeInfo[T]) ([]nodeInfo[T], *kdNode[T], direction) {
492+
if len(arr) == 0 {
493+
return arr, nil, left
494+
}
495+
li := len(arr) - 1
496+
return arr[:li], arr[li].node, arr[li].dir
497+
}
498+
411499
func findMin[T Comparable[T]](d, tcd, cd int, r *kdNode[T]) *T {
412500
if r == nil {
413501
return nil

kdtree_2d_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,61 @@ func Test2DNearestNeighbor5(t *testing.T) {
251251
}
252252
}
253253

254+
func tensor2DSortFunc(a, b Tensor2D) int {
255+
if a[0] != b[0] {
256+
return a[0] - b[0]
257+
} else {
258+
return a[1] - b[1]
259+
}
260+
}
261+
262+
func Test2DKNearestNeighbor1(t *testing.T) {
263+
const dimensions = 2
264+
type input struct {
265+
p Tensor2D
266+
k int
267+
}
268+
ps := []Tensor2D{
269+
{50, 50},
270+
{10, 25},
271+
{40, 20},
272+
{25, 80},
273+
{70, 70},
274+
{60, 10},
275+
{60, 90},
276+
}
277+
tree := kdtree.NewKDTreeWithValues(dimensions, ps)
278+
testTable := map[string]struct {
279+
input input
280+
expected []Tensor2D
281+
}{
282+
"Find the 2 closest neighbors to a point that is not in the KD tree.": {
283+
input: input{p: [2]int{25, 25}, k: 2},
284+
expected: []Tensor2D{{40, 20}, {10, 25}},
285+
},
286+
"The closest neighbor to a point that is in the KD tree.": {
287+
input: input{p: [2]int{60, 90}, k: 1},
288+
expected: []Tensor2D{{60, 90}},
289+
},
290+
"The three closest neighbors to a point that is in the KD tree.": {
291+
input: input{p: [2]int{70, 70}, k: 3},
292+
expected: []Tensor2D{{50, 50}, {60, 90}, {70, 70}},
293+
},
294+
}
295+
for name, st := range testTable {
296+
t.Run(name, func(t *testing.T) {
297+
nns := tree.KNearestNeighbor(st.input.p, st.input.k)
298+
slices.SortFunc(nns, tensor2DSortFunc)
299+
slices.SortFunc(st.expected, tensor2DSortFunc)
300+
for i := range len(st.expected) {
301+
if st.expected[i][0] != nns[i][0] || st.expected[i][1] != nns[i][1] {
302+
t.Fatalf("Expected point: %v, got %v", st.expected[i], nns[i])
303+
}
304+
}
305+
})
306+
}
307+
}
308+
254309
func Test2DNodeAddition1(t *testing.T) {
255310
const dimensions = 2
256311
ps := []Tensor2D{

priorityQueue.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package kdtree
2+
3+
import (
4+
"container/heap"
5+
"slices"
6+
)
7+
8+
type Item[T Comparable[T]] struct {
9+
Data T
10+
Priority int
11+
}
12+
13+
type BoundedPriorityQueue[T Comparable[T]] struct {
14+
data []Item[T]
15+
capacity int
16+
}
17+
18+
func NewBoundedPriorityQueue[T Comparable[T]](maxSize int) BoundedPriorityQueue[T] {
19+
return BoundedPriorityQueue[T]{
20+
data: nil,
21+
capacity: maxSize,
22+
}
23+
}
24+
25+
func (pq BoundedPriorityQueue[T]) Len() int { return len(pq.data) }
26+
27+
func (pq BoundedPriorityQueue[T]) Less(i, j int) bool {
28+
// We want Pop to give us the highest, not lowest, priority so we use greater than here.
29+
return pq.data[i].Priority > pq.data[j].Priority
30+
}
31+
32+
func (pq BoundedPriorityQueue[T]) Swap(i, j int) {
33+
pq.data[i], pq.data[j] = pq.data[j], pq.data[i]
34+
}
35+
36+
func (pq *BoundedPriorityQueue[T]) Push(value any) {
37+
item := value.(Item[T])
38+
isFull := pq.Len() == pq.capacity
39+
40+
if isFull {
41+
if pq.data[0].Priority <= item.Priority {
42+
return
43+
}
44+
heap.Pop(pq)
45+
}
46+
pq.data = append(pq.data, item)
47+
}
48+
49+
func (pq *BoundedPriorityQueue[T]) Pop() any {
50+
n := pq.Len()
51+
item := pq.data[n-1]
52+
pq.data = slices.Delete(pq.data, n-1, n)
53+
return item
54+
}
55+
56+
func (pq *BoundedPriorityQueue[T]) Peek() any {
57+
return pq.data[0]
58+
}
59+
60+
func (pq *BoundedPriorityQueue[T]) Capacity() int {
61+
return pq.capacity
62+
}

priorityQueue_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package kdtree_test
2+
3+
import (
4+
"container/heap"
5+
"testing"
6+
7+
kdtree "github.com/rishitc/go-kd-tree"
8+
)
9+
10+
// This example creates a PriorityQueue with some items, adds and manipulates an item,
11+
// and then removes the items in priority order.
12+
func TestBoundedPriorityQueue(t *testing.T) {
13+
table := []struct {
14+
name string
15+
inputElems []kdtree.Item[Tensor2D] // Some items and their priorities.
16+
inputMaxSize int
17+
expLen int
18+
expOrderedElems []Tensor2D
19+
}{
20+
{
21+
name: "More number of input elements as compared to the max allowed size of the bounded priority queue",
22+
inputElems: []kdtree.Item[Tensor2D]{
23+
{Data: Tensor2D{1, 2}, Priority: 3},
24+
{Data: Tensor2D{3, 5}, Priority: 2},
25+
{Data: Tensor2D{6, 7}, Priority: 4},
26+
},
27+
inputMaxSize: 2,
28+
expLen: 2,
29+
expOrderedElems: []Tensor2D{
30+
{1, 2},
31+
{3, 5},
32+
},
33+
},
34+
{
35+
name: "Less number of input elements as compared to the max allowed size of the bounded priority queue",
36+
inputElems: []kdtree.Item[Tensor2D]{
37+
{Data: Tensor2D{1, 2}, Priority: 3},
38+
{Data: Tensor2D{3, 5}, Priority: 2},
39+
{Data: Tensor2D{6, 7}, Priority: 4},
40+
},
41+
inputMaxSize: 5,
42+
expLen: 3,
43+
expOrderedElems: []Tensor2D{
44+
{6, 7},
45+
{1, 2},
46+
{3, 5},
47+
},
48+
},
49+
{
50+
name: "Equal number of input elements as compared to the max allowed size of the bounded priority queue",
51+
inputElems: []kdtree.Item[Tensor2D]{
52+
{Data: Tensor2D{1, 2}, Priority: 3},
53+
{Data: Tensor2D{3, 5}, Priority: 2},
54+
{Data: Tensor2D{6, 7}, Priority: 4},
55+
},
56+
inputMaxSize: 3,
57+
expLen: 3,
58+
expOrderedElems: []Tensor2D{
59+
{6, 7},
60+
{1, 2},
61+
{3, 5},
62+
},
63+
},
64+
{
65+
name: "Inserting two elements with the same priority and the one inserted first should be retained",
66+
inputElems: []kdtree.Item[Tensor2D]{
67+
{Data: Tensor2D{1, 2}, Priority: 2},
68+
{Data: Tensor2D{3, 5}, Priority: 3},
69+
{Data: Tensor2D{4, 8}, Priority: 4},
70+
{Data: Tensor2D{6, 7}, Priority: 4},
71+
},
72+
inputMaxSize: 3,
73+
expLen: 3,
74+
expOrderedElems: []Tensor2D{
75+
{4, 8},
76+
{3, 5},
77+
{1, 2},
78+
},
79+
},
80+
}
81+
82+
for _, st := range table {
83+
t.Run(st.name, func(t *testing.T) {
84+
pq := kdtree.NewBoundedPriorityQueue[Tensor2D](st.inputMaxSize)
85+
for _, elem := range st.inputElems {
86+
data, priority := elem.Data, elem.Priority
87+
heap.Push(&pq, kdtree.Item[Tensor2D]{
88+
Data: data,
89+
Priority: priority,
90+
})
91+
}
92+
93+
if pq.Len() != st.expLen {
94+
t.Errorf("Expected size to be %v, got %v", st.expLen, pq.Len())
95+
}
96+
97+
for i := range st.expLen {
98+
headElem := pq.Peek().(kdtree.Item[Tensor2D]).Data
99+
expElem := st.expOrderedElems[i]
100+
if expElem[0] != headElem[0] && expElem[1] != headElem[1] {
101+
t.Errorf("Expected next element to be %v, got %v", expElem, headElem)
102+
}
103+
heap.Pop(&pq)
104+
}
105+
})
106+
}
107+
}
108+

0 commit comments

Comments
 (0)