Skip to content

Commit d952825

Browse files
committed
kNN: Add heap logic and sort requirements
1 parent 76052e0 commit d952825

File tree

3 files changed

+78
-65
lines changed

3 files changed

+78
-65
lines changed

benchmarks/automatic/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,8 @@ BENCHMARK(Benchmarks::Point::InsertUnique)->Arg(10)->Arg(20)->Arg(50)->Arg(100)-
668668
BENCHMARK(Benchmarks::Point::Update)->Arg(10)->Arg(20)->Arg(50)->Arg(100)->Arg(1000)->Arg(10000)->Unit(benchmark::kMillisecond);
669669
BENCHMARK(Benchmarks::Point::Contains)->Arg(1000)->Arg(10000);
670670
BENCHMARK(Benchmarks::Point::RangeSearch)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
671-
BENCHMARK(Benchmarks::Point::GetNearestNeighbors)->Arg(1000)->Arg(10000)->Unit(benchmark::kNanosecond);
672-
BENCHMARK(Benchmarks::Point::GetNearestNeighbors<6>)->Arg(1000)->Arg(10000)->Unit(benchmark::kNanosecond);
671+
BENCHMARK(Benchmarks::Point::GetNearestNeighbors)->Arg(1000)->Arg(10000)->Unit(benchmark::kMillisecond);
672+
BENCHMARK(Benchmarks::Point::GetNearestNeighbors<6>)->Arg(1000)->Arg(10000)->Unit(benchmark::kMillisecond);
673673
BENCHMARK(Benchmarks::Point::GetNearestNeighbors<63>)->Arg(1000)->Arg(10000)->Unit(benchmark::kMillisecond);
674674
BENCHMARK(Benchmarks::Point::FrustumCulling)->Arg(1000)->Arg(10000)->Unit(benchmark::kMillisecond);
675675
BENCHMARK(Benchmarks::Box::Create<3, 0, false>)->Arg(10)->Arg(20)->Arg(50)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000)->Arg(1000000)->Unit(benchmark::kMillisecond);

octree.h

Lines changed: 73 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3660,48 +3660,57 @@ namespace OrthoTree
36603660

36613661
private: // K Nearest Neighbor helpers
36623662
static inline void AddEntityDistance(
3663-
auto const& entities, TVector const& searchPoint, TContainer const& points, std::vector<EntityDistance>& neighborEntities, TGeometry maxDistance) noexcept
3663+
auto const& entities,
3664+
TVector const& searchPoint,
3665+
TContainer const& points,
3666+
std::vector<EntityDistance>& neighborEntities,
3667+
std::size_t neighborNo,
3668+
TGeometry& maxDistanceWithin) noexcept
36643669
{
36653670
for (auto const entityID : entities)
36663671
{
36673672
const auto distance = AD::Distance(searchPoint, detail::at(points, entityID));
3668-
if (distance < maxDistance)
3669-
{
3673+
3674+
// maxDistanceWithin is implemented for tolerance handling: distance should be less than maxDistanceWithin
3675+
if (distance >= maxDistanceWithin)
3676+
continue;
3677+
3678+
if (neighborEntities.size() < neighborNo - 1)
36703679
neighborEntities.push_back({ { distance }, entityID });
3680+
else
3681+
{
3682+
if (neighborEntities.size() < neighborNo)
3683+
{
3684+
std::make_heap(neighborEntities.begin(), neighborEntities.end());
3685+
neighborEntities.push_back({ { distance }, entityID });
3686+
}
3687+
else
3688+
{
3689+
std::pop_heap(neighborEntities.begin(), neighborEntities.end());
3690+
neighborEntities.back() = { { distance }, entityID };
3691+
}
3692+
std::push_heap(neighborEntities.begin(), neighborEntities.end());
3693+
maxDistanceWithin = neighborEntities.front().Distance;
36713694
}
36723695
}
36733696
}
36743697

3675-
static inline IGM::Geometry GetFarestDistance(std::vector<EntityDistance>& neighborEntities, std::size_t neighborNo, IGM::Geometry maxDistance) noexcept
3676-
{
3677-
if (neighborEntities.size() < neighborNo)
3678-
{
3679-
return maxDistance;
3680-
}
3681-
3682-
auto const farestNeighborID = neighborNo - 1;
3683-
std::nth_element(neighborEntities.begin(), std::next(neighborEntities.begin(), farestNeighborID), neighborEntities.end());
3684-
return neighborEntities[farestNeighborID].Distance;
3685-
}
3686-
3687-
static std::vector<TEntityID> ConvertEntityDistanceToList(std::vector<EntityDistance>& neighborEntities, std::size_t neighborNo) noexcept
3698+
static inline constexpr std::vector<TEntityID> ConvertEntityDistanceToList(std::vector<EntityDistance>& neighborEntities, std::size_t neighborNo) noexcept
36883699
{
36893700
auto entityIDs = std::vector<TEntityID>();
36903701
if (neighborEntities.empty())
3691-
{
36923702
return entityIDs;
3693-
}
36943703

3695-
if (neighborNo < neighborEntities.size())
3696-
{
3697-
std::nth_element(neighborEntities.begin(), std::next(neighborEntities.begin(), neighborNo - 1), neighborEntities.end());
3698-
}
3704+
if (neighborEntities.size() < neighborNo)
3705+
std::sort(neighborEntities.begin(), neighborEntities.end());
3706+
else
3707+
std::sort_heap(neighborEntities.begin(), neighborEntities.end());
36993708

3700-
auto const entityNo = std::min(neighborNo, neighborEntities.size());
3701-
auto const lastIt = std::next(neighborEntities.begin(), entityNo);
3709+
auto const entityNo = neighborEntities.size();
3710+
entityIDs.resize(entityNo);
3711+
for (std::size_t i = 0; i < entityNo; ++i)
3712+
entityIDs[i] = neighborEntities[i].EntityID;
37023713

3703-
entityIDs.reserve(entityNo);
3704-
std::transform(neighborEntities.begin(), lastIt, std::back_inserter(entityIDs), [](auto const& ed) { return ed.EntityID; });
37053714
return entityIDs;
37063715
}
37073716

@@ -3714,83 +3723,90 @@ namespace OrthoTree
37143723
return IGM::GetBoxWallDistanceAD(searchPoint, centerPoint, halfSize, isInsideConsideredAsZero);
37153724
}
37163725

3717-
void VisitNodesInDFSWithChildrenEdit(MortonNodeIDCR key, auto const& procedure, auto const& childNodeKeyEditor) const noexcept
3726+
void VisitNodesInDFSWithChildrenEdit(
3727+
depth_t stackID, std::pair<Node const*, TGeometry> const& nodeWithDistance, auto const& procedure, auto const& childNodeKeyEditor) const noexcept
37183728
{
3719-
auto const& node = this->GetNode(key);
3720-
if (!procedure(node))
3729+
if (!procedure(nodeWithDistance))
37213730
return;
37223731

3723-
for (auto const [childKey, _] : childNodeKeyEditor(node.GetChildren()))
3724-
this->VisitNodesInDFSWithChildrenEdit(childKey, procedure, childNodeKeyEditor);
3732+
auto const childStackID = stackID + 1;
3733+
for (auto const& childNodeWithDistance : childNodeKeyEditor(nodeWithDistance.first->GetChildren(), stackID))
3734+
this->VisitNodesInDFSWithChildrenEdit(childStackID, childNodeWithDistance, procedure, childNodeKeyEditor);
37253735
}
37263736

37273737
public:
3728-
// K Nearest Neighbor
3729-
std::vector<TEntityID> GetNearestNeighbors(TVector const& searchPoint, std::size_t neighborNo, TGeometry maxDistance, TContainer const& points) const noexcept
3738+
// Get K Nearest Neighbor sorted by distance (point distance should be less than maxDistanceWithin, it is used as a Tolerance check)
3739+
std::vector<TEntityID> GetNearestNeighbors(
3740+
TVector const& searchPoint, std::size_t neighborNo, TGeometry maxDistanceWithin, TContainer const& points) const noexcept
37303741
{
37313742
auto neighborEntities = std::vector<EntityDistance>();
3732-
auto [smallestNodeKey, smallesDepthID] = this->FindSmallestNodeKeyWithDepth(this->template GetNodeID<true>(searchPoint));
3743+
neighborEntities.reserve(neighborNo);
3744+
3745+
auto smallestNodeKey = this->FindSmallestNodeKey(this->template GetNodeID<true>(searchPoint));
37333746
if (!SI::IsValidKey(smallestNodeKey))
3734-
{
37353747
smallestNodeKey = SI::GetRootKey();
3736-
smallesDepthID = 0;
3737-
}
37383748

3739-
auto const& smallestNode = this->GetNode(smallestNodeKey);
3740-
3741-
auto farestEntityDistance = maxDistance;
3749+
auto farestEntityDistance = maxDistanceWithin;
37423750
// Parent checks (in a usual case parents do not have entities)
37433751
for (auto nodeKey = smallestNodeKey; SI::IsValidKey(nodeKey); nodeKey = SI::GetParentKey(nodeKey))
3744-
{
3745-
auto const& node = this->GetNode(nodeKey);
3746-
AddEntityDistance(this->GetNodeEntities(node), searchPoint, points, neighborEntities, farestEntityDistance);
3747-
farestEntityDistance = GetFarestDistance(neighborEntities, neighborNo, farestEntityDistance);
3748-
}
3752+
AddEntityDistance(this->GetNodeEntities(nodeKey), searchPoint, points, neighborEntities, neighborNo, farestEntityDistance);
37493753

37503754
// Search in itself and the children
3755+
auto childrenDistanceStack = std::vector<std::vector<std::pair<Node const*, TGeometry>>>(this->GetMaxDepthID());
37513756
for (auto nodeKey = smallestNodeKey, prevNodeKey = MortonNodeID{}; SI::IsValidKey(nodeKey);
37523757
prevNodeKey = nodeKey, nodeKey = SI::GetParentKey(nodeKey))
37533758
{
3754-
auto const wallDistance = this->GetNodeWallDistance(searchPoint, smallestNodeKey, smallestNode, false);
3759+
auto const node = this->GetNode(nodeKey);
3760+
auto const wallDistance = this->GetNodeWallDistance(searchPoint, nodeKey, node, false);
37553761
this->VisitNodesInDFSWithChildrenEdit(
3756-
nodeKey,
3757-
[&](Node const& node) {
3758-
if (node.GetKey() == nodeKey)
3762+
0,
3763+
{ &node, wallDistance },
3764+
[&](std::pair<Node const*, TGeometry> const& nodeDistance) {
3765+
auto const& [childNode, childNodeDistance] = nodeDistance;
3766+
auto const childNodeKey = childNode->GetKey();
3767+
if (childNodeKey == nodeKey)
37593768
return true; // Self check was already done.
37603769

3761-
if (node.GetKey() == prevNodeKey)
3770+
if (childNodeKey == prevNodeKey)
37623771
return false; // Previous subtree was already checked.
37633772

3764-
AddEntityDistance(this->GetNodeEntities(node), searchPoint, points, neighborEntities, farestEntityDistance);
3765-
farestEntityDistance = GetFarestDistance(neighborEntities, neighborNo, farestEntityDistance);
3773+
if (childNodeDistance > farestEntityDistance)
3774+
return false;
3775+
3776+
AddEntityDistance(this->GetNodeEntities(*childNode), searchPoint, points, neighborEntities, neighborNo, farestEntityDistance);
37663777
return true;
37673778
},
3768-
[&](auto const& children) {
3769-
auto childrenDistance = std::vector<std::pair<MortonNodeID, TGeometry>>();
3779+
[&](auto const& children, depth_t stackID) -> std::span<std::pair<Node const*, TGeometry>> {
3780+
if (children.empty())
3781+
return {};
3782+
3783+
auto& childrenDistance = childrenDistanceStack[stackID];
3784+
childrenDistance.clear();
37703785
for (MortonNodeIDCR childNodeKey : children)
37713786
{
37723787
auto const& childNode = this->m_nodes.at(childNodeKey);
37733788
auto const wallDistance = this->GetNodeWallDistance(searchPoint, childNodeKey, childNode, true);
37743789
if (wallDistance > farestEntityDistance)
37753790
continue;
37763791

3777-
childrenDistance.emplace_back(childNodeKey, wallDistance);
3792+
childrenDistance.emplace_back(&childNode, wallDistance);
37783793
}
37793794

37803795
std::sort(childrenDistance.begin(), childrenDistance.end(), [&](auto const& leftDistance, auto const& rightDistance) {
37813796
return leftDistance.second < rightDistance.second;
37823797
});
37833798

3784-
return childrenDistance;
3799+
return std::span(childrenDistance);
37853800
});
37863801

3787-
if (farestEntityDistance < wallDistance || maxDistance < wallDistance)
3802+
if (farestEntityDistance < wallDistance)
37883803
break;
37893804
}
37903805

37913806
return ConvertEntityDistanceToList(neighborEntities, neighborNo);
37923807
}
37933808

3809+
// Get K Nearest Neighbor sorted by distance
37943810
inline std::vector<TEntityID> GetNearestNeighbors(TVector const& searchPoint, std::size_t neighborNo, TContainer const& points) const noexcept
37953811
{
37963812
return this->GetNearestNeighbors(searchPoint, neighborNo, std::numeric_limits<TGeometry>::max(), points);

unittests/general.tests.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3451,19 +3451,16 @@ namespace LongIntAdaptor
34513451
auto actual = tree.GetNearestNeighbors(searchPoint, k, points);
34523452
Assert::IsTrue(expected.size() == actual.size());
34533453

3454-
auto const comp = [&](auto const e1, auto const e2) {
3455-
return AD::Distance2(searchPoint, points[e1]) < AD::Distance2(searchPoint, points[e2]);
3456-
};
3457-
std::ranges::sort(expected, comp);
3458-
std::ranges::sort(actual, comp);
3459-
34603454
auto const areResultsEqual = expected == actual;
34613455

34623456
if (areResultsEqual)
34633457
continue;
34643458

34653459
for (std::size_t i = 0; i < expected.size(); ++i)
34663460
{
3461+
if (expected[i] == actual[i])
3462+
continue;
3463+
34673464
auto const expectedDistance = AD::Distance2(searchPoint, points[expected[i]]);
34683465
auto const actualDistance = AD::Distance2(searchPoint, points[actual[i]]);
34693466
Assert::IsTrue(std::abs(actualDistance - expectedDistance) < std::numeric_limits<double>::epsilon() * 10.0);

0 commit comments

Comments
 (0)