Skip to content

Commit ddb076c

Browse files
committed
feat: ✨ add an API to find the point with the max value in a user given dimension
1 parent 222fb8c commit ddb076c

File tree

3 files changed

+113
-0
lines changed

3 files changed

+113
-0
lines changed

comparable_utils.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ func min[T Comparable[T]](lhs, rhs *T, tcd int) *T {
77
return rhs
88
}
99

10+
func max[T Comparable[T]](lhs, rhs *T, tcd int) *T {
11+
if (*lhs).Order(*rhs, tcd) == Greater {
12+
return lhs
13+
}
14+
return rhs
15+
}
16+
1017
func distance[T Comparable[T]](src, dst *T) int {
1118
return (*src).Dist(*dst)
1219
}

kdtree.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ func (t *KDTree[T]) FindMin(targetDimension int) (T, bool) {
8282
return *res, true
8383
}
8484

85+
func (t *KDTree[T]) FindMax(targetDimension int) (T, bool) {
86+
if t.root == nil || targetDimension >= t.dimensions {
87+
return t.zeroVal, false
88+
}
89+
res := findMax(t.dimensions, targetDimension, 0, t.root)
90+
if res == nil {
91+
return t.zeroVal, false
92+
}
93+
return *res, true
94+
}
95+
8596
func (t *KDTree[T]) NearestNeighbor(value T) (T, bool) {
8697
res := nearestNeighbor(t.dimensions, &value, nil, 0, t.root)
8798
if res == nil {
@@ -426,3 +437,32 @@ func findMin[T Comparable[T]](d, tcd, cd int, r *kdNode[T]) *T {
426437
return min(lMin, min(rMin, &r.value, tcd), tcd)
427438
}
428439
}
440+
441+
func findMax[T Comparable[T]](d, tcd, cd int, r *kdNode[T]) *T {
442+
if r == nil {
443+
return nil
444+
}
445+
446+
var lMax *T
447+
var rMax *T
448+
ncd := (cd + 1) % d
449+
rMax = findMax(d, tcd, ncd, r.right)
450+
if tcd != cd {
451+
lMax = findMax(d, tcd, ncd, r.left)
452+
}
453+
if lMax == nil && rMax == nil {
454+
return &r.value
455+
} else if lMax == nil {
456+
if (*rMax).Order(r.value, tcd) == Greater {
457+
return rMax
458+
}
459+
return &r.value
460+
} else if rMax == nil {
461+
if (*lMax).Order(r.value, tcd) == Greater {
462+
return lMax
463+
}
464+
return &r.value
465+
} else {
466+
return max(lMax, max(rMax, &r.value, tcd), tcd)
467+
}
468+
}

kdtree_2d_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,72 @@ func Test2DFindMin2(t *testing.T) {
374374
}
375375
}
376376

377+
func Test2DFindMax1(t *testing.T) {
378+
const dimensions = 2
379+
ps := []Tensor2D{
380+
{35, 90},
381+
{60, 80},
382+
{51, 75},
383+
{70, 70},
384+
{50, 50},
385+
{25, 40},
386+
{10, 30},
387+
{1, 10},
388+
{55, 1},
389+
}
390+
tree := kdtree.NewKDTreeWithValues(dimensions, ps)
391+
testTable := []struct {
392+
input int
393+
expected Tensor2D
394+
}{
395+
{
396+
input: 0,
397+
expected: Tensor2D{70, 70},
398+
},
399+
{
400+
input: 1,
401+
expected: Tensor2D{35, 90},
402+
},
403+
}
404+
for _, v := range testTable {
405+
nn, ok := tree.FindMax(v.input)
406+
if !ok || !slices.Equal(nn[:], v.expected[:]) {
407+
t.Fatalf("Expected closest point: %v, got %v", v.expected, nn)
408+
}
409+
}
410+
}
411+
412+
func Test2DFindMax2(t *testing.T) {
413+
const dimensions = 2
414+
ps := []Tensor2D{
415+
{35, 90},
416+
{60, 80},
417+
{51, 75},
418+
{70, 70},
419+
{50, 50},
420+
{25, 40},
421+
{10, 30},
422+
{1, 10},
423+
{55, 1},
424+
}
425+
tree := kdtree.NewKDTreeWithValues(dimensions, ps)
426+
testTable := []struct {
427+
input int
428+
expected Tensor2D
429+
}{
430+
{
431+
input: 1,
432+
expected: Tensor2D{55, 1},
433+
},
434+
}
435+
for _, v := range testTable {
436+
nn, ok := tree.FindMin(v.input)
437+
if !ok || !slices.Equal(nn[:], v.expected[:]) {
438+
t.Fatalf("Expected closest point: %v, got %v", v.expected, nn)
439+
}
440+
}
441+
}
442+
377443
func Test2DDeleteAllNodesInTree(t *testing.T) {
378444
treeNodes := kdtree.NewKDNode(Tensor2D{5, 6})
379445
tree := kdtree.NewTestKDTree(2, treeNodes)

0 commit comments

Comments
 (0)