Skip to content

Commit bb93ac8

Browse files
committed
Reimplement forward graph traversal to account for multiple inputs
1 parent 32e6cb8 commit bb93ac8

File tree

1 file changed

+48
-8
lines changed

1 file changed

+48
-8
lines changed

src/pagoda/graph/traversal/forward.cpp

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "forward.h"
2+
#include "pagoda/graph/node.h"
23

4+
#include <algorithm>
5+
#include <map>
36
#include <pagoda/graph/query/input_node.h>
47
#include <iterator>
58

@@ -9,8 +12,51 @@ namespace pagoda::graph::traversal
912
{
1013
Forward::Forward(Graph& graph) : Traversal(graph)
1114
{
12-
query::InputNode q(graph, [this](NodePtr n) { m_nodesToVisit.push(n); });
15+
std::queue<NodePtr> nodes;
16+
std::map<NodePtr, uint32_t> nodeDistances;
17+
query::InputNode q(graph, [&nodes, &nodeDistances](NodePtr n) {
18+
nodes.push(n);
19+
nodeDistances.emplace(n, 0);
20+
});
1321
graph.ExecuteQuery(q);
22+
23+
while (!nodes.empty()) {
24+
NodePtr front = nodes.front();
25+
nodes.pop();
26+
const uint32_t thisNodeDistance = nodeDistances[front];
27+
28+
NodeSet outNodes;
29+
GetOutputNodes(front, std::inserter(outNodes, std::end(outNodes)));
30+
for (auto n : outNodes) {
31+
auto iter = nodeDistances.find(n);
32+
const auto nextNodeDistance = iter == nodeDistances.end() ?
33+
0 : // First time seeing this node
34+
iter->second; // We've seen this node
35+
36+
// Update distance
37+
if (thisNodeDistance + 1 > nextNodeDistance) {
38+
// We have found a longer path
39+
nodeDistances[n] = iter->second = thisNodeDistance + 1;
40+
}
41+
42+
nodes.push(n);
43+
}
44+
}
45+
46+
std::vector<NodePtr> sortedNodes;
47+
sortedNodes.reserve(nodeDistances.size());
48+
for (const auto& [n, dist] : nodeDistances) {
49+
sortedNodes.emplace_back(n);
50+
}
51+
52+
std::sort(sortedNodes.begin(), sortedNodes.end(),
53+
[&nodeDistances] (const NodePtr& lhs, const NodePtr& rhs) {
54+
return nodeDistances[lhs] < nodeDistances[rhs];
55+
});
56+
57+
for (const auto& n : sortedNodes) {
58+
m_nodesToVisit.push(n);
59+
}
1460
}
1561

1662
Forward::~Forward() {}
@@ -19,13 +65,7 @@ NodePtr Forward::Get() { return m_nodesToVisit.front(); }
1965

2066
bool Forward::Advance()
2167
{
22-
auto front = m_nodesToVisit.front();
23-
m_nodesToVisit.pop();
24-
NodeSet outNodes;
25-
GetOutputNodes(front, std::inserter(outNodes, std::end(outNodes)));
26-
for (auto n : outNodes) {
27-
m_nodesToVisit.push(n);
28-
}
68+
m_nodesToVisit.pop();
2969
return HasNext();
3070
}
3171

0 commit comments

Comments
 (0)