Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void
split_opt(llvm::RvsdgModule & rm)
{
// TODO: figure out which optimizations to use here
jlm::llvm::DeadNodeElimination dne;
jlm::llvm::DeadNodeElimination dne({ llvm::DNEGammaNodeHandler::GetInstance() });
jlm::hls::cne cne;
jlm::llvm::InvariantValueRedirection ivr;
jlm::llvm::tginversion tgi;
Expand All @@ -76,7 +76,7 @@ void
pre_opt(jlm::llvm::RvsdgModule & rm)
{
// TODO: figure out which optimizations to use here
jlm::llvm::DeadNodeElimination dne;
jlm::llvm::DeadNodeElimination dne({ llvm::DNEGammaNodeHandler::GetInstance() });
jlm::hls::cne cne;
jlm::llvm::InvariantValueRedirection ivr;
jlm::llvm::tginversion tgi;
Expand Down Expand Up @@ -438,7 +438,7 @@ rvsdg2rhls(llvm::RvsdgModule & rhls, util::StatisticsCollector & collector)

merge_gamma(rhls);

llvm::DeadNodeElimination llvmDne;
llvm::DeadNodeElimination llvmDne({ llvm::DNEGammaNodeHandler::GetInstance() });
llvmDne.Run(rhls, collector);

mem_sep_argument(rhls);
Expand Down
245 changes: 113 additions & 132 deletions jlm/llvm/opt/DeadNodeElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,78 +16,100 @@
namespace jlm::llvm
{

/** \brief Dead Node Elimination context class
*
* This class keeps track of all the nodes and outputs that are alive. In contrast to all other
* nodes, a simple node is considered alive if already a single of its outputs is alive. For this
* reason, this class keeps separately track of simple nodes and therefore avoids to store all its
* outputs (and instead stores the node itself). By marking the entire node as alive, we also avoid
* that we reiterate through all inputs of this node again in the future. The following example
* illustrates the issue:
*
* o1 ... oN = Node2 i1 ... iN
* p1 ... pN = Node1 o1 ... oN
*
* When we mark o1 as alive, we actually mark the entire Node2 as alive. This means that when we try
* to mark o2 alive in the future, we can immediately stop marking instead of reiterating through i1
* ... iN again. Thus, by marking the entire simple node instead of just its outputs, we reduce the
* runtime for marking Node2 from O(oN x iN) to O(oN + iN).
*/
class DeadNodeElimination::Context final
DNEStructuralNodeHandler::~DNEStructuralNodeHandler() = default;

DNEGammaNodeHandler::~DNEGammaNodeHandler() = default;

DNEGammaNodeHandler::DNEGammaNodeHandler() = default;

std::type_index
DNEGammaNodeHandler::GetTypeInfo() const
{
public:
void
MarkAlive(const jlm::rvsdg::output & output)
{
if (auto simpleOutput = dynamic_cast<const rvsdg::SimpleOutput *>(&output))
{
SimpleNodes_.Insert(simpleOutput->node());
return;
}
return typeid(rvsdg::GammaNode);
}

Outputs_.Insert(&output);
std::optional<std::vector<rvsdg::output *>>
DNEGammaNodeHandler::ComputeMarkPhaseContinuations(const rvsdg::output & output) const
{
if (const auto gammaNode = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(output))
{
return std::vector{ gammaNode->MapBranchArgumentEntryVar(output).input->origin() };
}

bool
IsAlive(const jlm::rvsdg::output & output) const noexcept
if (const auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(output))
{
if (auto simpleOutput = dynamic_cast<const rvsdg::SimpleOutput *>(&output))
std::vector continuations({ gammaNode->predicate()->origin() });
for (const auto & result : gammaNode->MapOutputExitVar(output).branchResult)
{
return SimpleNodes_.Contains(simpleOutput->node());
continuations.push_back(result->origin());
}

return Outputs_.Contains(&output);
return continuations;
}

bool
IsAlive(const rvsdg::Node & node) const noexcept
return std::nullopt;
}

void
DNEGammaNodeHandler::SweepNodeEntry(
rvsdg::StructuralNode & structuralNode,
const DNEContext & context) const
{
auto & gammaNode = *util::AssertedCast<rvsdg::GammaNode>(&structuralNode);

// Remove dead arguments and inputs
for (size_t n = gammaNode.ninputs() - 1; n >= 1; n--)
{
if (auto simpleNode = dynamic_cast<const jlm::rvsdg::SimpleNode *>(&node))
auto input = gammaNode.input(n);

bool alive = false;
for (auto & argument : input->arguments)
{
return SimpleNodes_.Contains(simpleNode);
if (context.IsAlive(argument))
{
alive = true;
break;
}
}

for (size_t n = 0; n < node.noutputs(); n++)
if (!alive)
{
if (IsAlive(*node.output(n)))
for (size_t r = 0; r < gammaNode.nsubregions(); r++)
{
return true;
gammaNode.subregion(r)->RemoveArgument(n - 1);
}
gammaNode.RemoveInput(n);
}

return false;
}
}

void
DNEGammaNodeHandler::SweepNodeExit(
rvsdg::StructuralNode & structuralNode,
const DNEContext & context) const
{
auto & gammaNode = *util::AssertedCast<rvsdg::GammaNode>(&structuralNode);

static std::unique_ptr<Context>
Create()
// Remove dead outputs and results
for (size_t n = gammaNode.noutputs() - 1; n != static_cast<size_t>(-1); n--)
{
return std::make_unique<Context>();
if (context.IsAlive(*gammaNode.output(n)))
{
continue;
}

for (size_t r = 0; r < gammaNode.nsubregions(); r++)
{
gammaNode.subregion(r)->RemoveResult(n);
}
gammaNode.RemoveOutput(n);
}
}

private:
util::HashSet<const jlm::rvsdg::SimpleNode *> SimpleNodes_;
util::HashSet<const jlm::rvsdg::output *> Outputs_;
};
DNEStructuralNodeHandler *
DNEGammaNodeHandler::GetInstance()
{
static DNEGammaNodeHandler singleton;
return &singleton;
}

/** \brief Dead Node Elimination statistics class
*
Expand Down Expand Up @@ -141,29 +163,35 @@ class DeadNodeElimination::Statistics final : public util::Statistics

DeadNodeElimination::~DeadNodeElimination() noexcept = default;

DeadNodeElimination::DeadNodeElimination() = default;
DeadNodeElimination::DeadNodeElimination(
const std::vector<const DNEStructuralNodeHandler *> & handlers)
{
for (const auto handler : handlers)
{
JLM_ASSERT(Handlers_.find(handler->GetTypeInfo()) == Handlers_.end());
Handlers_[handler->GetTypeInfo()] = handler;
}
}

void
DeadNodeElimination::run(rvsdg::Region & region)
{
Context_ = Context::Create();
Context_ = DNEContext{};

MarkRegion(region);
SweepRegion(region);

// Discard internal state to free up memory after we are done
Context_.reset();
}

void
DeadNodeElimination::Run(
rvsdg::RvsdgModule & module,
util::StatisticsCollector & statisticsCollector)
{
Context_ = Context::Create();

auto & rvsdg = module.Rvsdg();

Context_ = DNEContext{};
auto statistics = Statistics::Create(module.SourceFilePath().value());

statistics->StartMarkStatistics(rvsdg);
MarkRegion(rvsdg.GetRootRegion());
statistics->StopMarkStatistics();
Expand All @@ -173,9 +201,6 @@ DeadNodeElimination::Run(
statistics->StopSweepStatistics(rvsdg);

statisticsCollector.CollectDemandedStatistics(std::move(statistics));

// Discard internal state to free up memory after we are done
Context_.reset();
}

void
Expand All @@ -190,32 +215,28 @@ DeadNodeElimination::MarkRegion(const rvsdg::Region & region)
void
DeadNodeElimination::MarkOutput(const jlm::rvsdg::output & output)
{
if (Context_->IsAlive(output))
if (Context_.IsAlive(output))
{
return;
}

Context_->MarkAlive(output);
Context_.MarkAlive(output);

if (is<rvsdg::GraphImport>(&output))
{
return;
}

if (auto gamma = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(output))
for (const auto [_, handler] : Handlers_)
{
MarkOutput(*gamma->predicate()->origin());
for (const auto & result : gamma->MapOutputExitVar(output).branchResult)
if (const auto continuations = handler->ComputeMarkPhaseContinuations(output))
{
MarkOutput(*result->origin());
for (const auto & continuation : continuations.value())
{
MarkOutput(*continuation);
}
return;
}
return;
}

if (auto gamma = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(output))
{
MarkOutput(*gamma->MapBranchArgumentEntryVar(output).input->origin());
return;
}

if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
Expand Down Expand Up @@ -316,7 +337,7 @@ DeadNodeElimination::SweepRvsdg(rvsdg::Graph & rvsdg) const
// Remove dead imports
for (size_t n = rvsdg.GetRootRegion().narguments() - 1; n != static_cast<size_t>(-1); n--)
{
if (!Context_->IsAlive(*rvsdg.GetRootRegion().argument(n)))
if (!Context_.IsAlive(*rvsdg.GetRootRegion().argument(n)))
{
rvsdg.GetRootRegion().RemoveArgument(n);
}
Expand All @@ -338,7 +359,7 @@ DeadNodeElimination::SweepRegion(rvsdg::Region & region) const
{
for (auto node : *it)
{
if (!Context_->IsAlive(*node))
if (!Context_.IsAlive(*node))
{
remove(node);
continue;
Expand All @@ -357,10 +378,20 @@ DeadNodeElimination::SweepRegion(rvsdg::Region & region) const
void
DeadNodeElimination::SweepStructuralNode(rvsdg::StructuralNode & node) const
{
auto sweepGamma = [](auto & d, auto & n)
if (const auto it = Handlers_.find(typeid(node)); it != Handlers_.end())
{
d.SweepGamma(*util::AssertedCast<rvsdg::GammaNode>(&n));
};
const auto handler = it->second;
handler->SweepNodeExit(node, Context_);

for (size_t r = 0; r < node.nsubregions(); r++)
{
SweepRegion(*node.subregion(r));
}

handler->SweepNodeEntry(node, Context_);
return;
}

auto sweepTheta = [](auto & d, auto & n)
{
d.SweepTheta(*util::AssertedCast<rvsdg::ThetaNode>(&n));
Expand All @@ -381,8 +412,7 @@ DeadNodeElimination::SweepStructuralNode(rvsdg::StructuralNode & node) const
static std::unordered_map<
std::type_index,
std::function<void(const DeadNodeElimination &, rvsdg::StructuralNode &)>>
map({ { typeid(rvsdg::GammaOperation), sweepGamma },
{ typeid(rvsdg::ThetaOperation), sweepTheta },
map({ { typeid(rvsdg::ThetaOperation), sweepTheta },
{ typeid(llvm::LlvmLambdaOperation), sweepLambda },
{ typeid(phi::operation), sweepPhi },
{ typeid(delta::operation), sweepDelta } });
Expand All @@ -392,63 +422,14 @@ DeadNodeElimination::SweepStructuralNode(rvsdg::StructuralNode & node) const
map[typeid(op)](*this, node);
}

void
DeadNodeElimination::SweepGamma(rvsdg::GammaNode & gammaNode) const
{
// Remove dead outputs and results
for (size_t n = gammaNode.noutputs() - 1; n != static_cast<size_t>(-1); n--)
{
if (Context_->IsAlive(*gammaNode.output(n)))
{
continue;
}

for (size_t r = 0; r < gammaNode.nsubregions(); r++)
{
gammaNode.subregion(r)->RemoveResult(n);
}
gammaNode.RemoveOutput(n);
}

// Sweep gamma subregions
for (size_t r = 0; r < gammaNode.nsubregions(); r++)
{
SweepRegion(*gammaNode.subregion(r));
}

// Remove dead arguments and inputs
for (size_t n = gammaNode.ninputs() - 1; n >= 1; n--)
{
auto input = gammaNode.input(n);

bool alive = false;
for (auto & argument : input->arguments)
{
if (Context_->IsAlive(argument))
{
alive = true;
break;
}
}
if (!alive)
{
for (size_t r = 0; r < gammaNode.nsubregions(); r++)
{
gammaNode.subregion(r)->RemoveArgument(n - 1);
}
gammaNode.RemoveInput(n);
}
}
}

void
DeadNodeElimination::SweepTheta(rvsdg::ThetaNode & thetaNode) const
{
// Determine loop variables to be removed.
std::vector<rvsdg::ThetaNode::LoopVar> loopvars;
for (const auto & loopvar : thetaNode.GetLoopVars())
{
if (!Context_->IsAlive(*loopvar.pre) && !Context_->IsAlive(*loopvar.output))
if (!Context_.IsAlive(*loopvar.pre) && !Context_.IsAlive(*loopvar.output))
{
loopvar.post->divert_to(loopvar.pre);
loopvars.push_back(loopvar);
Expand Down Expand Up @@ -485,7 +466,7 @@ DeadNodeElimination::SweepPhi(phi::node & phiNode) const
auto argument = output.argument();

// A recursion variable is only dead iff its output AND argument are dead
auto isDead = !Context_->IsAlive(output) && !Context_->IsAlive(*argument);
auto isDead = !Context_.IsAlive(output) && !Context_.IsAlive(*argument);
if (isDead)
{
deadRecursionArguments.Insert(argument);
Expand Down
Loading