Skip to content

Commit 145a479

Browse files
committed
perf!: ⚡ improve the time and space efficiency of the KNearestNeighbor API
The fundamental changes made to improve the time and space efficiency of the `KNearestNeighbor` API are: 1. Replace the interface-based implementation of the priority queue with a generics-based implementation. 2. Preallocate an underlying slice with a capacity equal to the maximum allowed size of the priority queue. 3. The `Item` struct now stores a pointer to the original data instead of storing a copy of the original data. 4. Return a slice of pointers to the k-nearest neighbours (i.e., `[]*T`) instead of a slice of copies of the k-nearest neighbours' values (i.e., `[]T`). These changes have led to a remarkable ~34.26% enhancement in performance (i.e., `ns/op`). The memory efficiency has also substantially improved, with ~5 times less `B/op` and ~113 times less `allocs/op`. BREAKING CHANGE: To improve the performance of the `KNearestNeighbor` API, I changed its return type from `[]T` to `[]*T` to reduce the number of costly data copies that occur. Utilizing pointers to access the actual data is a significantly more efficient approach in terms of CPU and memory usage. With just a single level of indirection, we can save on a considerable amount of data copying, ensuring a more streamlined implementation. With this change to the `KNearestNeighbor` API's return type in this commit, any code that uses the `KNearestNeighbor` API will no longer work with future releases of this library created after this commit. Affected codebases must migrate to the updated version of the `KNearestNeighbor` API to be compatible with future releases of this library.
1 parent f9956c4 commit 145a479

File tree

5 files changed

+116
-35
lines changed

5 files changed

+116
-35
lines changed

internal/utils/generic_pq.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// This file is a rewrite of Go's "container/heap" implementation with the usages of interfaces replaced with generics.
2+
// Link to Go's "container/heap" implementation: https://cs.opensource.google/go/go/+/refs/tags/go1.23.0:src/container/heap/heap.go
3+
4+
package internal
5+
6+
import "sort"
7+
8+
type Interface[T any] interface {
9+
sort.Interface
10+
Push(x T) // add x as element Len()
11+
Pop() T // remove and return element Len() - 1.
12+
}
13+
14+
func Init[V any, T Interface[V]](h T) {
15+
// heapify the input
16+
n := h.Len()
17+
for i := n/2 - 1; i >= 0; i-- {
18+
down(h, i, n)
19+
}
20+
}
21+
22+
func Push[V any, T Interface[V]](h T, x V) {
23+
h.Push(x)
24+
up(h, h.Len()-1)
25+
}
26+
27+
func Pop[V any, T Interface[V]](h T) V {
28+
n := h.Len() - 1
29+
h.Swap(0, n)
30+
down(h, 0, n)
31+
return h.Pop()
32+
}
33+
34+
func Remove[V any, T Interface[V]](h T, i int) any {
35+
n := h.Len() - 1
36+
if n != i {
37+
h.Swap(i, n)
38+
if !down(h, i, n) {
39+
up(h, i)
40+
}
41+
}
42+
return h.Pop()
43+
}
44+
45+
func Fix[V any, T Interface[V]](h T, i int) {
46+
if !down(h, i, h.Len()) {
47+
up(h, i)
48+
}
49+
}
50+
51+
func up[V any, T Interface[V]](h T, j int) {
52+
for {
53+
i := (j - 1) / 2 // parent
54+
if i == j || !h.Less(j, i) {
55+
break
56+
}
57+
h.Swap(i, j)
58+
j = i
59+
}
60+
}
61+
62+
func down[V any, T Interface[V]](h T, i0, n int) bool {
63+
i := i0
64+
for {
65+
j1 := 2*i + 1
66+
if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
67+
break
68+
}
69+
j := j1 // left child
70+
if j2 := j1 + 1; j2 < n && h.Less(j2, j1) {
71+
j = j2 // = 2*i + 2 // right child
72+
}
73+
if !h.Less(j, i) {
74+
break
75+
}
76+
h.Swap(i, j)
77+
i = j
78+
}
79+
return i > i0
80+
}

kdtree.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package kdtree
22

33
import (
4-
"container/heap"
54
"fmt"
65
"sort"
76
"strings"
@@ -403,17 +402,18 @@ func nearestNeighbor[T Comparable[T]](d int, v, nn *T, cd int, r *kdNode[T]) *T
403402
return nn
404403
}
405404

406-
func (t *KDTree[T]) KNearestNeighbor(value T, k int) []T {
405+
func (t *KDTree[T]) KNearestNeighbor(value T, k int) []*T {
407406
if t == nil || t.root == nil || t.size < k {
408407
return nil
409408
}
410409

411410
pqRes := NewBoundedPriorityQueue[T](k)
412411
kNearestNeighbor(k, t.dimensions, &value, &pqRes, 0, t.root)
413412

414-
res := make([]T, 0, pqRes.Len())
415-
for pqRes.Len() != 0 {
416-
d := heap.Pop(&pqRes).(Item[T]).Data
413+
res := make([]*T, 0, k)
414+
for range k {
415+
// heap with a preset capacity
416+
d := internal.Pop(&pqRes).Data
417417
res = append(res, d)
418418
}
419419

@@ -459,8 +459,8 @@ func kNearestNeighbor[T Comparable[T]](k, d int, v *T, pq *BoundedPriorityQueue[
459459
ncd = (ncd - 1 + d) % d // Go back to the dimension used for splitting at the leaf node.
460460
for path, cn, cDir := popLast(path); cn != nil; path, cn, cDir = popLast(path) {
461461
currentDistance := (*v).Dist(cn.value)
462-
heap.Push(pq, Item[T]{
463-
Data: cn.value,
462+
internal.Push(pq, Item[T]{
463+
Data: &cn.value,
464464
Priority: currentDistance,
465465
})
466466

@@ -478,7 +478,7 @@ func kNearestNeighbor[T Comparable[T]](k, d int, v *T, pq *BoundedPriorityQueue[
478478
}
479479

480480
func getFarthestDistance[T Comparable[T]](pq *BoundedPriorityQueue[T]) int {
481-
v := pq.Peek().(Item[T])
481+
v := pq.Peek()
482482
return v.Priority
483483
}
484484

priorityQueue.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package kdtree
22

33
import (
4-
"container/heap"
54
"slices"
5+
6+
heap "github.com/rishitc/go-kd-tree/internal/utils"
67
)
78

89
type Item[T Comparable[T]] struct {
9-
Data T
10+
Data *T
1011
Priority int
1112
}
1213

@@ -17,7 +18,7 @@ type BoundedPriorityQueue[T Comparable[T]] struct {
1718

1819
func NewBoundedPriorityQueue[T Comparable[T]](maxSize int) BoundedPriorityQueue[T] {
1920
return BoundedPriorityQueue[T]{
20-
data: nil,
21+
data: make([]Item[T], 0, maxSize),
2122
capacity: maxSize,
2223
}
2324
}
@@ -33,8 +34,7 @@ func (pq BoundedPriorityQueue[T]) Swap(i, j int) {
3334
pq.data[i], pq.data[j] = pq.data[j], pq.data[i]
3435
}
3536

36-
func (pq *BoundedPriorityQueue[T]) Push(value any) {
37-
item := value.(Item[T])
37+
func (pq *BoundedPriorityQueue[T]) Push(item Item[T]) {
3838
isFull := pq.Len() == pq.capacity
3939

4040
if isFull {
@@ -46,14 +46,14 @@ func (pq *BoundedPriorityQueue[T]) Push(value any) {
4646
pq.data = append(pq.data, item)
4747
}
4848

49-
func (pq *BoundedPriorityQueue[T]) Pop() any {
49+
func (pq *BoundedPriorityQueue[T]) Pop() Item[T] {
5050
n := pq.Len()
5151
item := pq.data[n-1]
5252
pq.data = slices.Delete(pq.data, n-1, n)
5353
return item
5454
}
5555

56-
func (pq *BoundedPriorityQueue[T]) Peek() any {
56+
func (pq *BoundedPriorityQueue[T]) Peek() Item[T] {
5757
return pq.data[0]
5858
}
5959

tests/kdtree_2d_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func Test2DNearestNeighbor5(t *testing.T) {
184184
}
185185
}
186186

187-
func tensor2DSortFunc(a, b types.Tensor2D) int {
187+
func tensor2DSortFunc(a, b *types.Tensor2D) int {
188188
if a[0] != b[0] {
189189
return a[0] - b[0]
190190
} else {
@@ -210,19 +210,19 @@ func Test2DKNearestNeighbor1(t *testing.T) {
210210
tree := kdtree.NewKDTreeWithValues(dimensions, ps)
211211
testTable := map[string]struct {
212212
input input
213-
expected []types.Tensor2D
213+
expected []*types.Tensor2D
214214
}{
215215
"Find the 2 closest neighbors to a point that is not in the KD tree.": {
216216
input: input{p: [2]int{25, 25}, k: 2},
217-
expected: []types.Tensor2D{{40, 20}, {10, 25}},
217+
expected: []*types.Tensor2D{{40, 20}, {10, 25}},
218218
},
219219
"The closest neighbor to a point that is in the KD tree.": {
220220
input: input{p: [2]int{60, 90}, k: 1},
221-
expected: []types.Tensor2D{{60, 90}},
221+
expected: []*types.Tensor2D{{60, 90}},
222222
},
223223
"The three closest neighbors to a point that is in the KD tree.": {
224224
input: input{p: [2]int{70, 70}, k: 3},
225-
expected: []types.Tensor2D{{50, 50}, {60, 90}, {70, 70}},
225+
expected: []*types.Tensor2D{{50, 50}, {60, 90}, {70, 70}},
226226
},
227227
}
228228
for name, st := range testTable {

tests/priorityQueue_test.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package tests
22

33
import (
4-
"container/heap"
54
"testing"
65

6+
heap "github.com/rishitc/go-kd-tree/internal/utils"
7+
78
boundedpq "github.com/rishitc/go-kd-tree"
89
types "github.com/rishitc/go-kd-tree/internal/types"
910
)
@@ -19,9 +20,9 @@ func TestBoundedPriorityQueue(t *testing.T) {
1920
{
2021
name: "More number of input elements as compared to the max allowed size of the bounded priority queue",
2122
inputElems: []boundedpq.Item[types.Tensor2D]{
22-
{Data: types.Tensor2D{1, 2}, Priority: 3},
23-
{Data: types.Tensor2D{3, 5}, Priority: 2},
24-
{Data: types.Tensor2D{6, 7}, Priority: 4},
23+
{Data: &types.Tensor2D{1, 2}, Priority: 3},
24+
{Data: &types.Tensor2D{3, 5}, Priority: 2},
25+
{Data: &types.Tensor2D{6, 7}, Priority: 4},
2526
},
2627
inputMaxSize: 2,
2728
expLen: 2,
@@ -33,9 +34,9 @@ func TestBoundedPriorityQueue(t *testing.T) {
3334
{
3435
name: "Less number of input elements as compared to the max allowed size of the bounded priority queue",
3536
inputElems: []boundedpq.Item[types.Tensor2D]{
36-
{Data: types.Tensor2D{1, 2}, Priority: 3},
37-
{Data: types.Tensor2D{3, 5}, Priority: 2},
38-
{Data: types.Tensor2D{6, 7}, Priority: 4},
37+
{Data: &types.Tensor2D{1, 2}, Priority: 3},
38+
{Data: &types.Tensor2D{3, 5}, Priority: 2},
39+
{Data: &types.Tensor2D{6, 7}, Priority: 4},
3940
},
4041
inputMaxSize: 5,
4142
expLen: 3,
@@ -48,9 +49,9 @@ func TestBoundedPriorityQueue(t *testing.T) {
4849
{
4950
name: "Equal number of input elements as compared to the max allowed size of the bounded priority queue",
5051
inputElems: []boundedpq.Item[types.Tensor2D]{
51-
{Data: types.Tensor2D{1, 2}, Priority: 3},
52-
{Data: types.Tensor2D{3, 5}, Priority: 2},
53-
{Data: types.Tensor2D{6, 7}, Priority: 4},
52+
{Data: &types.Tensor2D{1, 2}, Priority: 3},
53+
{Data: &types.Tensor2D{3, 5}, Priority: 2},
54+
{Data: &types.Tensor2D{6, 7}, Priority: 4},
5455
},
5556
inputMaxSize: 3,
5657
expLen: 3,
@@ -63,10 +64,10 @@ func TestBoundedPriorityQueue(t *testing.T) {
6364
{
6465
name: "Inserting two elements with the same priority and the one inserted first should be retained",
6566
inputElems: []boundedpq.Item[types.Tensor2D]{
66-
{Data: types.Tensor2D{1, 2}, Priority: 2},
67-
{Data: types.Tensor2D{3, 5}, Priority: 3},
68-
{Data: types.Tensor2D{4, 8}, Priority: 4},
69-
{Data: types.Tensor2D{6, 7}, Priority: 4},
67+
{Data: &types.Tensor2D{1, 2}, Priority: 2},
68+
{Data: &types.Tensor2D{3, 5}, Priority: 3},
69+
{Data: &types.Tensor2D{4, 8}, Priority: 4},
70+
{Data: &types.Tensor2D{6, 7}, Priority: 4},
7071
},
7172
inputMaxSize: 3,
7273
expLen: 3,
@@ -94,7 +95,7 @@ func TestBoundedPriorityQueue(t *testing.T) {
9495
}
9596

9697
for i := range st.expLen {
97-
headElem := pq.Peek().(boundedpq.Item[types.Tensor2D]).Data
98+
headElem := pq.Peek().Data
9899
expElem := st.expOrderedElems[i]
99100
if expElem[0] != headElem[0] && expElem[1] != headElem[1] {
100101
t.Errorf("Expected next element to be %v, got %v", expElem, headElem)

0 commit comments

Comments
 (0)