@@ -765,3 +765,55 @@ func Test2DTreeEncodeDecode(t *testing.T) {
765
765
t .Fatalf ("Tree does not match expected tree structure\n Expected:\n %s\n Got:\n %s" , expectedTree , tree )
766
766
}
767
767
}
768
+
769
+ func Test2DTreeBalance (t * testing.T ) {
770
+ treeNodes := kdtree .NewKDNode (Tensor2D {20 , 15 }).
771
+ SetLeft (
772
+ kdtree .NewKDNode (Tensor2D {3 , 25 }),
773
+ ).
774
+ SetRight (
775
+ kdtree .NewKDNode (Tensor2D {30 , 40 }).
776
+ SetRight (
777
+ kdtree .NewKDNode (Tensor2D {25 , 50 }).
778
+ SetRight (
779
+ kdtree .NewKDNode (Tensor2D {28 , 47 }).
780
+ SetRight (
781
+ kdtree .NewKDNode (Tensor2D {40 , 60 }),
782
+ ),
783
+ ),
784
+ ),
785
+ )
786
+ tree1 := kdtree .NewTestKDTree (2 , treeNodes )
787
+
788
+ expTreeNodes1 := kdtree .NewKDNode (Tensor2D {25 , 50 }).
789
+ SetLeft (
790
+ kdtree .NewKDNode (Tensor2D {20 , 15 }).
791
+ SetRight (
792
+ kdtree .NewKDNode (Tensor2D {3 , 25 }),
793
+ ),
794
+ ).
795
+ SetRight (
796
+ kdtree .NewKDNode (Tensor2D {28 , 47 }).
797
+ SetLeft (
798
+ kdtree .NewKDNode (Tensor2D {30 , 40 }),
799
+ ).
800
+ SetRight (
801
+ kdtree .NewKDNode (Tensor2D {40 , 60 }),
802
+ ),
803
+ )
804
+ testTable := []struct {
805
+ input * kdtree.KDTree [Tensor2D ]
806
+ expected * kdtree.KDTree [Tensor2D ]
807
+ }{
808
+ {
809
+ input : tree1 ,
810
+ expected : kdtree .NewTestKDTree (2 , expTreeNodes1 ),
811
+ },
812
+ }
813
+ for _ , v := range testTable {
814
+ v .input .Balance ()
815
+ if ! kdtree .IdenticalTrees (v .input , v .expected ) {
816
+ t .Fatalf ("Tree does not match expected tree structure\n Expected:\n %s\n Got:\n %s" , v .expected , v .input )
817
+ }
818
+ }
819
+ }
0 commit comments