Skip to content

Commit d3eadbd

Browse files
committed
feat: ✨ add an API to rebalance the kd-tree on demand
1 parent ddb076c commit d3eadbd

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

kdtree.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,11 @@ func (t *KDTree[T]) Encode() []byte {
221221
return builder.FinishedBytes()
222222
}
223223

224+
// Balance rebalance the k-d tree by recreating it.
225+
func (t *KDTree[T]) Balance() {
226+
t.root = NewKDTreeWithValues(t.dimensions, t.Values()).root
227+
}
228+
224229
func query[T Comparable[T]](getRelativePosition func(T, int) RelativePosition, d int, res *[]T, r *kdNode[T], cd int) {
225230
if r == nil {
226231
return

kdtree_2d_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,3 +765,55 @@ func Test2DTreeEncodeDecode(t *testing.T) {
765765
t.Fatalf("Tree does not match expected tree structure\nExpected:\n%s\nGot:\n%s", expectedTree, tree)
766766
}
767767
}
768+
769+
func Test2DTreeBalance(t *testing.T) {
770+
treeNodes := kdtree.NewKDNode(Tensor2D{20, 15}).
771+
SetLeft(
772+
kdtree.NewKDNode(Tensor2D{3, 25}),
773+
).
774+
SetRight(
775+
kdtree.NewKDNode(Tensor2D{30, 40}).
776+
SetRight(
777+
kdtree.NewKDNode(Tensor2D{25, 50}).
778+
SetRight(
779+
kdtree.NewKDNode(Tensor2D{28, 47}).
780+
SetRight(
781+
kdtree.NewKDNode(Tensor2D{40, 60}),
782+
),
783+
),
784+
),
785+
)
786+
tree1 := kdtree.NewTestKDTree(2, treeNodes)
787+
788+
expTreeNodes1 := kdtree.NewKDNode(Tensor2D{25, 50}).
789+
SetLeft(
790+
kdtree.NewKDNode(Tensor2D{20, 15}).
791+
SetRight(
792+
kdtree.NewKDNode(Tensor2D{3, 25}),
793+
),
794+
).
795+
SetRight(
796+
kdtree.NewKDNode(Tensor2D{28, 47}).
797+
SetLeft(
798+
kdtree.NewKDNode(Tensor2D{30, 40}),
799+
).
800+
SetRight(
801+
kdtree.NewKDNode(Tensor2D{40, 60}),
802+
),
803+
)
804+
testTable := []struct {
805+
input *kdtree.KDTree[Tensor2D]
806+
expected *kdtree.KDTree[Tensor2D]
807+
}{
808+
{
809+
input: tree1,
810+
expected: kdtree.NewTestKDTree(2, expTreeNodes1),
811+
},
812+
}
813+
for _, v := range testTable {
814+
v.input.Balance()
815+
if !kdtree.IdenticalTrees(v.input, v.expected) {
816+
t.Fatalf("Tree does not match expected tree structure\nExpected:\n%s\nGot:\n%s", v.expected, v.input)
817+
}
818+
}
819+
}

0 commit comments

Comments
 (0)