|
7 | 7 | "testing"
|
8 | 8 |
|
9 | 9 | kdtree "github.com/rishitc/go-kd-tree"
|
| 10 | + "github.com/stretchr/testify/assert" |
10 | 11 | )
|
11 | 12 |
|
12 | 13 | type Tensor2D [2]int
|
@@ -817,3 +818,175 @@ func Test2DTreeBalance(t *testing.T) {
|
817 | 818 | }
|
818 | 819 | }
|
819 | 820 | }
|
| 821 | + |
| 822 | +func Test2DTree_Query(t *testing.T) { |
| 823 | + const dimensions = 2 |
| 824 | + inputTensor2D := []Tensor2D{{1, 0}, {1, 8}, {2, 2}, {2, 10}, {3, 4}, {4, 1}, {5, 4}, {6, 8}, {7, 4}, {7, 7}, {8, 2}, {8, 5}, {9, 9}, {3, 6}, {4, 2}, {9, 2}, {6, 5}, {3, 8}, {6, 2}, {1, 3}, {3, 3}, {6, 4}, {9, 8}, {2, 1}, {2, 8}, {3, 1}, {7, 3}, {3, 9}, {4, 4}, {5, 3}, {9, 6}} |
| 825 | + tests := []struct { |
| 826 | + name string |
| 827 | + tree *kdtree.KDTree[Tensor2D] |
| 828 | + input kdtree.RangeFunc[Tensor2D] |
| 829 | + expected []Tensor2D |
| 830 | + }{ |
| 831 | + { |
| 832 | + name: "out of range x (lower)", |
| 833 | + tree: kdtree.NewKDTreeWithValues(dimensions, inputTensor2D), |
| 834 | + input: func(td Tensor2D, i int) kdtree.RelativePosition { |
| 835 | + switch i { |
| 836 | + case -1: |
| 837 | + if x, y := td[0], td[1]; -2 <= x && x < -1 && 2 <= y && y < 10 { |
| 838 | + return kdtree.InRange |
| 839 | + } |
| 840 | + return kdtree.AfterRange |
| 841 | + case 0: |
| 842 | + if x := td[0]; x < -2 { |
| 843 | + return kdtree.BeforeRange |
| 844 | + } else if x >= -1 { |
| 845 | + return kdtree.AfterRange |
| 846 | + } else { |
| 847 | + return kdtree.InRange |
| 848 | + } |
| 849 | + case 1: |
| 850 | + if y := td[1]; y < 2 { |
| 851 | + return kdtree.BeforeRange |
| 852 | + } else if y >= 10 { |
| 853 | + return kdtree.AfterRange |
| 854 | + } else { |
| 855 | + return kdtree.InRange |
| 856 | + } |
| 857 | + } |
| 858 | + return kdtree.AfterRange |
| 859 | + }, |
| 860 | + expected: []Tensor2D(nil), |
| 861 | + }, |
| 862 | + { |
| 863 | + name: "out of range y (lower)", |
| 864 | + tree: kdtree.NewKDTreeWithValues(dimensions, inputTensor2D), |
| 865 | + input: func(td Tensor2D, i int) kdtree.RelativePosition { |
| 866 | + switch i { |
| 867 | + case -1: |
| 868 | + if x, y := td[0], td[1]; 2 <= x && x < 10 && -2 <= y && y < -1 { |
| 869 | + return kdtree.InRange |
| 870 | + } |
| 871 | + return kdtree.AfterRange |
| 872 | + case 0: |
| 873 | + if x := td[0]; x < 2 { |
| 874 | + return kdtree.BeforeRange |
| 875 | + } else if x >= 10 { |
| 876 | + return kdtree.AfterRange |
| 877 | + } else { |
| 878 | + return kdtree.InRange |
| 879 | + } |
| 880 | + case 1: |
| 881 | + if y := td[1]; y < -2 { |
| 882 | + return kdtree.BeforeRange |
| 883 | + } else if y >= -1 { |
| 884 | + return kdtree.AfterRange |
| 885 | + } else { |
| 886 | + return kdtree.InRange |
| 887 | + } |
| 888 | + } |
| 889 | + return kdtree.AfterRange |
| 890 | + }, |
| 891 | + expected: []Tensor2D(nil), |
| 892 | + }, |
| 893 | + { |
| 894 | + name: "out of range x (higher)", |
| 895 | + tree: kdtree.NewKDTreeWithValues(dimensions, inputTensor2D), |
| 896 | + input: func(td Tensor2D, i int) kdtree.RelativePosition { |
| 897 | + switch i { |
| 898 | + case -1: |
| 899 | + if x, y := td[0], td[1]; 20 <= x && x < 30 && 2 <= y && y < 10 { |
| 900 | + return kdtree.InRange |
| 901 | + } |
| 902 | + return kdtree.AfterRange |
| 903 | + case 0: |
| 904 | + if x := td[0]; x < 20 { |
| 905 | + return kdtree.BeforeRange |
| 906 | + } else if x >= 30 { |
| 907 | + return kdtree.AfterRange |
| 908 | + } else { |
| 909 | + return kdtree.InRange |
| 910 | + } |
| 911 | + case 1: |
| 912 | + if y := td[1]; y < 2 { |
| 913 | + return kdtree.BeforeRange |
| 914 | + } else if y >= 10 { |
| 915 | + return kdtree.AfterRange |
| 916 | + } else { |
| 917 | + return kdtree.InRange |
| 918 | + } |
| 919 | + } |
| 920 | + return kdtree.AfterRange |
| 921 | + }, |
| 922 | + expected: []Tensor2D(nil), |
| 923 | + }, |
| 924 | + { |
| 925 | + name: "out of range y (higher)", |
| 926 | + tree: kdtree.NewKDTreeWithValues(dimensions, inputTensor2D), |
| 927 | + input: func(td Tensor2D, i int) kdtree.RelativePosition { |
| 928 | + switch i { |
| 929 | + case -1: |
| 930 | + if x, y := td[0], td[1]; 2 <= x && x < 10 && 20 <= y && y < 30 { |
| 931 | + return kdtree.InRange |
| 932 | + } |
| 933 | + return kdtree.AfterRange |
| 934 | + case 0: |
| 935 | + if x := td[0]; x < 2 { |
| 936 | + return kdtree.BeforeRange |
| 937 | + } else if x >= 10 { |
| 938 | + return kdtree.AfterRange |
| 939 | + } else { |
| 940 | + return kdtree.InRange |
| 941 | + } |
| 942 | + case 1: |
| 943 | + if y := td[1]; y < 20 { |
| 944 | + return kdtree.BeforeRange |
| 945 | + } else if y >= 30 { |
| 946 | + return kdtree.AfterRange |
| 947 | + } else { |
| 948 | + return kdtree.InRange |
| 949 | + } |
| 950 | + } |
| 951 | + return kdtree.AfterRange |
| 952 | + }, |
| 953 | + expected: []Tensor2D(nil), |
| 954 | + }, |
| 955 | + { |
| 956 | + name: "some values in range", |
| 957 | + tree: kdtree.NewKDTreeWithValues(dimensions, inputTensor2D), |
| 958 | + input: func(td Tensor2D, i int) kdtree.RelativePosition { |
| 959 | + switch i { |
| 960 | + case -1: |
| 961 | + if x, y := td[0], td[1]; 1 <= x && x < 2 && 2 <= y && y < 10 { |
| 962 | + return kdtree.InRange |
| 963 | + } |
| 964 | + return kdtree.AfterRange |
| 965 | + case 0: |
| 966 | + if x := td[0]; x < -2 { |
| 967 | + return kdtree.BeforeRange |
| 968 | + } else if x >= -1 { |
| 969 | + return kdtree.AfterRange |
| 970 | + } else { |
| 971 | + return kdtree.InRange |
| 972 | + } |
| 973 | + case 1: |
| 974 | + if y := td[1]; y < 2 { |
| 975 | + return kdtree.BeforeRange |
| 976 | + } else if y >= 10 { |
| 977 | + return kdtree.AfterRange |
| 978 | + } else { |
| 979 | + return kdtree.InRange |
| 980 | + } |
| 981 | + } |
| 982 | + return kdtree.AfterRange |
| 983 | + }, |
| 984 | + expected: []Tensor2D{{1, 3}, {1, 8}}, |
| 985 | + }, |
| 986 | + } |
| 987 | + for _, test := range tests { |
| 988 | + t.Run(test.name, func(t *testing.T) { |
| 989 | + assert.Equal(t, test.expected, test.tree.Query(test.input)) |
| 990 | + }) |
| 991 | + } |
| 992 | +} |
0 commit comments