From 85a98697bccf70bb43d2465c182115b2ef89fdae Mon Sep 17 00:00:00 2001 From: Nico Reissmann Date: Sat, 26 Apr 2025 07:43:57 +0200 Subject: [PATCH 1/3] TODO --- jlm/llvm/opt/DeadNodeElimination.hpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/jlm/llvm/opt/DeadNodeElimination.hpp b/jlm/llvm/opt/DeadNodeElimination.hpp index 5f3edd2a2..5b2bb52c3 100644 --- a/jlm/llvm/opt/DeadNodeElimination.hpp +++ b/jlm/llvm/opt/DeadNodeElimination.hpp @@ -34,6 +34,15 @@ namespace phi class node; } +class DNEStructuralNodeHandler +{ +public: + virtual ~DNEStructuralNodeHandler(); + + virtual void + SweepNode(rvsdg::StructuralNode & structuralNode) = 0; +}; + /** \brief Dead Node Elimination Optimization * * Dead Node Elimination removes all nodes that do not contribute to the result of a computation. A From 105836f7d30889d8ff3dea0ad218756386d1234b Mon Sep 17 00:00:00 2001 From: Nico Reissmann Date: Sat, 26 Apr 2025 21:25:35 +0200 Subject: [PATCH 2/3] TODO --- jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp | 6 +- jlm/llvm/opt/DeadNodeElimination.cpp | 241 ++++++++---------- jlm/llvm/opt/DeadNodeElimination.hpp | 123 ++++++++- .../opt/alias-analyses/MemoryStateEncoder.cpp | 2 +- jlm/tooling/Command.cpp | 3 +- .../jlm/llvm/opt/TestDeadNodeElimination.cpp | 3 +- tests/jlm/llvm/opt/test-unroll.cpp | 2 +- 7 files changed, 234 insertions(+), 146 deletions(-) diff --git a/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp b/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp index 5a3cc2684..aad8ea76b 100644 --- a/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp +++ b/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp @@ -58,7 +58,7 @@ void split_opt(llvm::RvsdgModule & rm) { // TODO: figure out which optimizations to use here - jlm::llvm::DeadNodeElimination dne; + jlm::llvm::DeadNodeElimination dne({ llvm::DNEGammaNodeHandler::GetInstance() }); jlm::hls::cne cne; jlm::llvm::InvariantValueRedirection ivr; jlm::llvm::tginversion tgi; @@ -76,7 +76,7 @@ void pre_opt(jlm::llvm::RvsdgModule & rm) { // TODO: figure out which optimizations to use here - jlm::llvm::DeadNodeElimination dne; + jlm::llvm::DeadNodeElimination dne({ llvm::DNEGammaNodeHandler::GetInstance() }); jlm::hls::cne cne; jlm::llvm::InvariantValueRedirection ivr; jlm::llvm::tginversion tgi; @@ -438,7 +438,7 @@ rvsdg2rhls(llvm::RvsdgModule & rhls, util::StatisticsCollector & collector) merge_gamma(rhls); - llvm::DeadNodeElimination llvmDne; + llvm::DeadNodeElimination llvmDne({ llvm::DNEGammaNodeHandler::GetInstance() }); llvmDne.Run(rhls, collector); mem_sep_argument(rhls); diff --git a/jlm/llvm/opt/DeadNodeElimination.cpp b/jlm/llvm/opt/DeadNodeElimination.cpp index 2bd3b7a86..391a303f2 100644 --- a/jlm/llvm/opt/DeadNodeElimination.cpp +++ b/jlm/llvm/opt/DeadNodeElimination.cpp @@ -16,78 +16,100 @@ namespace jlm::llvm { -/** \brief Dead Node Elimination context class - * - * This class keeps track of all the nodes and outputs that are alive. In contrast to all other - * nodes, a simple node is considered alive if already a single of its outputs is alive. For this - * reason, this class keeps separately track of simple nodes and therefore avoids to store all its - * outputs (and instead stores the node itself). By marking the entire node as alive, we also avoid - * that we reiterate through all inputs of this node again in the future. The following example - * illustrates the issue: - * - * o1 ... oN = Node2 i1 ... iN - * p1 ... pN = Node1 o1 ... oN - * - * When we mark o1 as alive, we actually mark the entire Node2 as alive. This means that when we try - * to mark o2 alive in the future, we can immediately stop marking instead of reiterating through i1 - * ... iN again. Thus, by marking the entire simple node instead of just its outputs, we reduce the - * runtime for marking Node2 from O(oN x iN) to O(oN + iN). - */ -class DeadNodeElimination::Context final +DNEStructuralNodeHandler::~DNEStructuralNodeHandler() = default; + +DNEGammaNodeHandler::~DNEGammaNodeHandler() = default; + +DNEGammaNodeHandler::DNEGammaNodeHandler() = default; + +std::type_index +DNEGammaNodeHandler::GetTypeInfo() const { -public: - void - MarkAlive(const jlm::rvsdg::output & output) - { - if (auto simpleOutput = dynamic_cast(&output)) - { - SimpleNodes_.Insert(simpleOutput->node()); - return; - } + return typeid(rvsdg::GammaNode); +} - Outputs_.Insert(&output); +std::optional> +DNEGammaNodeHandler::ComputeMarkPhaseContinuations(const rvsdg::output & output) const +{ + if (const auto gammaNode = rvsdg::TryGetRegionParentNode(output)) + { + return std::vector{ gammaNode->MapBranchArgumentEntryVar(output).input->origin() }; } - bool - IsAlive(const jlm::rvsdg::output & output) const noexcept + if (const auto gammaNode = rvsdg::TryGetOwnerNode(output)) { - if (auto simpleOutput = dynamic_cast(&output)) + std::vector continuations({ gammaNode->predicate()->origin() }); + for (const auto & result : gammaNode->MapOutputExitVar(output).branchResult) { - return SimpleNodes_.Contains(simpleOutput->node()); + continuations.push_back(result->origin()); } - - return Outputs_.Contains(&output); + return continuations; } - bool - IsAlive(const rvsdg::Node & node) const noexcept + return std::nullopt; +} + +void +DNEGammaNodeHandler::SweepNodeEntry( + rvsdg::StructuralNode & structuralNode, + const DNEContext & context) const +{ + auto & gammaNode = *util::AssertedCast(&structuralNode); + + // Remove dead arguments and inputs + for (size_t n = gammaNode.ninputs() - 1; n >= 1; n--) { - if (auto simpleNode = dynamic_cast(&node)) + auto input = gammaNode.input(n); + + bool alive = false; + for (auto & argument : input->arguments) { - return SimpleNodes_.Contains(simpleNode); + if (context.IsAlive(argument)) + { + alive = true; + break; + } } - - for (size_t n = 0; n < node.noutputs(); n++) + if (!alive) { - if (IsAlive(*node.output(n))) + for (size_t r = 0; r < gammaNode.nsubregions(); r++) { - return true; + gammaNode.subregion(r)->RemoveArgument(n - 1); } + gammaNode.RemoveInput(n); } - - return false; } +} - static std::unique_ptr - Create() +void +DNEGammaNodeHandler::SweepNodeExit( + rvsdg::StructuralNode & structuralNode, + const DNEContext & context) const +{ + auto & gammaNode = *util::AssertedCast(&structuralNode); + + // Remove dead outputs and results + for (size_t n = gammaNode.noutputs() - 1; n != static_cast(-1); n--) { - return std::make_unique(); + if (context.IsAlive(*gammaNode.output(n))) + { + continue; + } + + for (size_t r = 0; r < gammaNode.nsubregions(); r++) + { + gammaNode.subregion(r)->RemoveResult(n); + } + gammaNode.RemoveOutput(n); } +} -private: - util::HashSet SimpleNodes_; - util::HashSet Outputs_; -}; +DNEStructuralNodeHandler * +DNEGammaNodeHandler::GetInstance() +{ + static DNEGammaNodeHandler singleton; + return &singleton; +} /** \brief Dead Node Elimination statistics class * @@ -141,18 +163,17 @@ class DeadNodeElimination::Statistics final : public util::Statistics DeadNodeElimination::~DeadNodeElimination() noexcept = default; -DeadNodeElimination::DeadNodeElimination() = default; +DeadNodeElimination::DeadNodeElimination(std::vector handlers) + : Handlers_(std::move(handlers)) +{} void DeadNodeElimination::run(rvsdg::Region & region) { - Context_ = Context::Create(); + Context_ = DNEContext{}; MarkRegion(region); SweepRegion(region); - - // Discard internal state to free up memory after we are done - Context_.reset(); } void @@ -160,10 +181,11 @@ DeadNodeElimination::Run( rvsdg::RvsdgModule & module, util::StatisticsCollector & statisticsCollector) { - Context_ = Context::Create(); - auto & rvsdg = module.Rvsdg(); + + Context_ = DNEContext{}; auto statistics = Statistics::Create(module.SourceFilePath().value()); + statistics->StartMarkStatistics(rvsdg); MarkRegion(rvsdg.GetRootRegion()); statistics->StopMarkStatistics(); @@ -173,9 +195,6 @@ DeadNodeElimination::Run( statistics->StopSweepStatistics(rvsdg); statisticsCollector.CollectDemandedStatistics(std::move(statistics)); - - // Discard internal state to free up memory after we are done - Context_.reset(); } void @@ -190,32 +209,28 @@ DeadNodeElimination::MarkRegion(const rvsdg::Region & region) void DeadNodeElimination::MarkOutput(const jlm::rvsdg::output & output) { - if (Context_->IsAlive(output)) + if (Context_.IsAlive(output)) { return; } - Context_->MarkAlive(output); + Context_.MarkAlive(output); if (is(&output)) { return; } - if (auto gamma = rvsdg::TryGetOwnerNode(output)) + for (const auto handler : Handlers_) { - MarkOutput(*gamma->predicate()->origin()); - for (const auto & result : gamma->MapOutputExitVar(output).branchResult) + if (const auto continuations = handler->ComputeMarkPhaseContinuations(output)) { - MarkOutput(*result->origin()); + for (const auto & continuation : continuations.value()) + { + MarkOutput(*continuation); + } + return; } - return; - } - - if (auto gamma = rvsdg::TryGetRegionParentNode(output)) - { - MarkOutput(*gamma->MapBranchArgumentEntryVar(output).input->origin()); - return; } if (auto theta = rvsdg::TryGetOwnerNode(output)) @@ -316,7 +331,7 @@ DeadNodeElimination::SweepRvsdg(rvsdg::Graph & rvsdg) const // Remove dead imports for (size_t n = rvsdg.GetRootRegion().narguments() - 1; n != static_cast(-1); n--) { - if (!Context_->IsAlive(*rvsdg.GetRootRegion().argument(n))) + if (!Context_.IsAlive(*rvsdg.GetRootRegion().argument(n))) { rvsdg.GetRootRegion().RemoveArgument(n); } @@ -338,7 +353,7 @@ DeadNodeElimination::SweepRegion(rvsdg::Region & region) const { for (auto node : *it) { - if (!Context_->IsAlive(*node)) + if (!Context_.IsAlive(*node)) { remove(node); continue; @@ -357,10 +372,22 @@ DeadNodeElimination::SweepRegion(rvsdg::Region & region) const void DeadNodeElimination::SweepStructuralNode(rvsdg::StructuralNode & node) const { - auto sweepGamma = [](auto & d, auto & n) + for (const auto handler : Handlers_) { - d.SweepGamma(*util::AssertedCast(&n)); - }; + if (handler->GetTypeInfo() == typeid(node)) + { + handler->SweepNodeExit(node, Context_); + + for (size_t r = 0; r < node.nsubregions(); r++) + { + SweepRegion(*node.subregion(r)); + } + + handler->SweepNodeEntry(node, Context_); + return; + } + } + auto sweepTheta = [](auto & d, auto & n) { d.SweepTheta(*util::AssertedCast(&n)); @@ -381,8 +408,7 @@ DeadNodeElimination::SweepStructuralNode(rvsdg::StructuralNode & node) const static std::unordered_map< std::type_index, std::function> - map({ { typeid(rvsdg::GammaOperation), sweepGamma }, - { typeid(rvsdg::ThetaOperation), sweepTheta }, + map({ { typeid(rvsdg::ThetaOperation), sweepTheta }, { typeid(llvm::LlvmLambdaOperation), sweepLambda }, { typeid(phi::operation), sweepPhi }, { typeid(delta::operation), sweepDelta } }); @@ -392,55 +418,6 @@ DeadNodeElimination::SweepStructuralNode(rvsdg::StructuralNode & node) const map[typeid(op)](*this, node); } -void -DeadNodeElimination::SweepGamma(rvsdg::GammaNode & gammaNode) const -{ - // Remove dead outputs and results - for (size_t n = gammaNode.noutputs() - 1; n != static_cast(-1); n--) - { - if (Context_->IsAlive(*gammaNode.output(n))) - { - continue; - } - - for (size_t r = 0; r < gammaNode.nsubregions(); r++) - { - gammaNode.subregion(r)->RemoveResult(n); - } - gammaNode.RemoveOutput(n); - } - - // Sweep gamma subregions - for (size_t r = 0; r < gammaNode.nsubregions(); r++) - { - SweepRegion(*gammaNode.subregion(r)); - } - - // Remove dead arguments and inputs - for (size_t n = gammaNode.ninputs() - 1; n >= 1; n--) - { - auto input = gammaNode.input(n); - - bool alive = false; - for (auto & argument : input->arguments) - { - if (Context_->IsAlive(argument)) - { - alive = true; - break; - } - } - if (!alive) - { - for (size_t r = 0; r < gammaNode.nsubregions(); r++) - { - gammaNode.subregion(r)->RemoveArgument(n - 1); - } - gammaNode.RemoveInput(n); - } - } -} - void DeadNodeElimination::SweepTheta(rvsdg::ThetaNode & thetaNode) const { @@ -448,7 +425,7 @@ DeadNodeElimination::SweepTheta(rvsdg::ThetaNode & thetaNode) const std::vector loopvars; for (const auto & loopvar : thetaNode.GetLoopVars()) { - if (!Context_->IsAlive(*loopvar.pre) && !Context_->IsAlive(*loopvar.output)) + if (!Context_.IsAlive(*loopvar.pre) && !Context_.IsAlive(*loopvar.output)) { loopvar.post->divert_to(loopvar.pre); loopvars.push_back(loopvar); @@ -485,7 +462,7 @@ DeadNodeElimination::SweepPhi(phi::node & phiNode) const auto argument = output.argument(); // A recursion variable is only dead iff its output AND argument are dead - auto isDead = !Context_->IsAlive(output) && !Context_->IsAlive(*argument); + auto isDead = !Context_.IsAlive(output) && !Context_.IsAlive(*argument); if (isDead) { deadRecursionArguments.Insert(argument); diff --git a/jlm/llvm/opt/DeadNodeElimination.hpp b/jlm/llvm/opt/DeadNodeElimination.hpp index 5b2bb52c3..74714c19b 100644 --- a/jlm/llvm/opt/DeadNodeElimination.hpp +++ b/jlm/llvm/opt/DeadNodeElimination.hpp @@ -9,6 +9,8 @@ #include #include +#include + namespace jlm::rvsdg { class GammaNode; @@ -34,13 +36,123 @@ namespace phi class node; } +/** \brief Dead Node Elimination context class + * + * This class keeps track of all the nodes and outputs that are alive. In contrast to all other + * nodes, a simple node is considered alive if already a single of its outputs is alive. For this + * reason, this class keeps separately track of simple nodes and therefore avoids to store all its + * outputs (and instead stores the node itself). By marking the entire node as alive, we also avoid + * that we reiterate through all inputs of this node again in the future. The following example + * illustrates the issue: + * + * o1 ... oN = Node2 i1 ... iN + * p1 ... pN = Node1 o1 ... oN + * + * When we mark o1 as alive, we actually mark the entire Node2 as alive. This means that when we try + * to mark o2 alive in the future, we can immediately stop marking instead of reiterating through i1 + * ... iN again. Thus, by marking the entire simple node instead of just its outputs, we reduce the + * runtime for marking Node2 from O(oN x iN) to O(oN + iN). + */ +class DNEContext final +{ +public: + void + MarkAlive(const rvsdg::output & output) + { + if (const auto simpleOutput = dynamic_cast(&output)) + { + SimpleNodes_.Insert(simpleOutput->node()); + return; + } + + Outputs_.Insert(&output); + } + + bool + IsAlive(const rvsdg::output & output) const noexcept + { + if (const auto simpleOutput = dynamic_cast(&output)) + { + return SimpleNodes_.Contains(simpleOutput->node()); + } + + return Outputs_.Contains(&output); + } + + bool + IsAlive(const rvsdg::Node & node) const noexcept + { + if (const auto simpleNode = dynamic_cast(&node)) + { + return SimpleNodes_.Contains(simpleNode); + } + + for (size_t n = 0; n < node.noutputs(); n++) + { + if (IsAlive(*node.output(n))) + { + return true; + } + } + + return false; + } + +private: + util::HashSet SimpleNodes_; + util::HashSet Outputs_; +}; + class DNEStructuralNodeHandler { public: virtual ~DNEStructuralNodeHandler(); + virtual std::type_index + GetTypeInfo() const = 0; + + virtual std::optional> + ComputeMarkPhaseContinuations(const rvsdg::output & output) const = 0; + virtual void - SweepNode(rvsdg::StructuralNode & structuralNode) = 0; + SweepNodeEntry(rvsdg::StructuralNode & structuralNode, const DNEContext & context) const = 0; + + virtual void + SweepNodeExit(rvsdg::StructuralNode & structuralNode, const DNEContext & context) const = 0; +}; + +class DNEGammaNodeHandler final : public DNEStructuralNodeHandler +{ +public: + ~DNEGammaNodeHandler() override; + + DNEGammaNodeHandler(const DNEGammaNodeHandler &) = delete; + + DNEGammaNodeHandler & + operator=(const DNEGammaNodeHandler &) = delete; + + DNEGammaNodeHandler(DNEGammaNodeHandler &&) = delete; + + DNEGammaNodeHandler & + operator=(DNEGammaNodeHandler &&) = delete; + + std::type_index + GetTypeInfo() const override; + + std::optional> + ComputeMarkPhaseContinuations(const rvsdg::output & output) const override; + + void + SweepNodeEntry(rvsdg::StructuralNode & structuralNode, const DNEContext & context) const override; + + void + SweepNodeExit(rvsdg::StructuralNode & structuralNode, const DNEContext & context) const override; + + static DNEStructuralNodeHandler * + GetInstance(); + +private: + DNEGammaNodeHandler(); }; /** \brief Dead Node Elimination Optimization @@ -60,13 +172,12 @@ class DNEStructuralNodeHandler */ class DeadNodeElimination final : public rvsdg::Transformation { - class Context; class Statistics; public: ~DeadNodeElimination() noexcept override; - DeadNodeElimination(); + explicit DeadNodeElimination(std::vector handlers); DeadNodeElimination(const DeadNodeElimination &) = delete; @@ -100,9 +211,6 @@ class DeadNodeElimination final : public rvsdg::Transformation void SweepStructuralNode(rvsdg::StructuralNode & node) const; - void - SweepGamma(rvsdg::GammaNode & gammaNode) const; - void SweepTheta(rvsdg::ThetaNode & thetaNode) const; @@ -115,7 +223,8 @@ class DeadNodeElimination final : public rvsdg::Transformation static void SweepDelta(delta::node & deltaNode); - std::unique_ptr Context_; + DNEContext Context_; + std::vector Handlers_; }; } diff --git a/jlm/llvm/opt/alias-analyses/MemoryStateEncoder.cpp b/jlm/llvm/opt/alias-analyses/MemoryStateEncoder.cpp index d0ece5723..f755a65b3 100644 --- a/jlm/llvm/opt/alias-analyses/MemoryStateEncoder.cpp +++ b/jlm/llvm/opt/alias-analyses/MemoryStateEncoder.cpp @@ -480,7 +480,7 @@ MemoryStateEncoder::Encode( Context_.reset(); // Remove all nodes that became dead throughout the encoding. - DeadNodeElimination deadNodeElimination; + DeadNodeElimination deadNodeElimination({ DNEGammaNodeHandler::GetInstance() }); deadNodeElimination.Run(rvsdgModule, statisticsCollector); } diff --git a/jlm/tooling/Command.cpp b/jlm/tooling/Command.cpp index 01d9d3a25..0d4985b7e 100644 --- a/jlm/tooling/Command.cpp +++ b/jlm/tooling/Command.cpp @@ -407,7 +407,8 @@ JlmOptCommand::CreateTransformation( case JlmOptCommandLineOptions::OptimizationId::CommonNodeElimination: return std::make_unique(); case JlmOptCommandLineOptions::OptimizationId::DeadNodeElimination: - return std::make_unique(); + return std::unique_ptr( + new llvm::DeadNodeElimination({ llvm::DNEGammaNodeHandler::GetInstance() })); case JlmOptCommandLineOptions::OptimizationId::FunctionInlining: return std::make_unique(); case JlmOptCommandLineOptions::OptimizationId::IfConversion: diff --git a/tests/jlm/llvm/opt/TestDeadNodeElimination.cpp b/tests/jlm/llvm/opt/TestDeadNodeElimination.cpp index e36a0cf17..e97f741d2 100644 --- a/tests/jlm/llvm/opt/TestDeadNodeElimination.cpp +++ b/tests/jlm/llvm/opt/TestDeadNodeElimination.cpp @@ -21,8 +21,9 @@ static void RunDeadNodeElimination(jlm::llvm::RvsdgModule & rvsdgModule) { + using namespace jlm::llvm; jlm::util::StatisticsCollector statisticsCollector; - jlm::llvm::DeadNodeElimination deadNodeElimination; + DeadNodeElimination deadNodeElimination({ DNEGammaNodeHandler::GetInstance() }); deadNodeElimination.Run(rvsdgModule, statisticsCollector); } diff --git a/tests/jlm/llvm/opt/test-unroll.cpp b/tests/jlm/llvm/opt/test-unroll.cpp index b5baa11c0..6ddc7dc3b 100644 --- a/tests/jlm/llvm/opt/test-unroll.cpp +++ b/tests/jlm/llvm/opt/test-unroll.cpp @@ -257,7 +257,7 @@ test_unknown_boundaries() assert(jlm::rvsdg::is(node)); /* Create cleaner output */ - DeadNodeElimination dne; + DeadNodeElimination dne({ DNEGammaNodeHandler::GetInstance() }); dne.Run(rm, statisticsCollector); // jlm::rvsdg::view(graph, stdout); } From f16ee696307a2ece91c21d0db7c9cf272d0b9597 Mon Sep 17 00:00:00 2001 From: Nico Reissmann Date: Sat, 26 Apr 2025 21:35:08 +0200 Subject: [PATCH 3/3] TODO --- jlm/llvm/opt/DeadNodeElimination.cpp | 34 ++++++++++++++++------------ jlm/llvm/opt/DeadNodeElimination.hpp | 4 ++-- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/jlm/llvm/opt/DeadNodeElimination.cpp b/jlm/llvm/opt/DeadNodeElimination.cpp index 391a303f2..354124ddf 100644 --- a/jlm/llvm/opt/DeadNodeElimination.cpp +++ b/jlm/llvm/opt/DeadNodeElimination.cpp @@ -163,9 +163,15 @@ class DeadNodeElimination::Statistics final : public util::Statistics DeadNodeElimination::~DeadNodeElimination() noexcept = default; -DeadNodeElimination::DeadNodeElimination(std::vector handlers) - : Handlers_(std::move(handlers)) -{} +DeadNodeElimination::DeadNodeElimination( + const std::vector & handlers) +{ + for (const auto handler : handlers) + { + JLM_ASSERT(Handlers_.find(handler->GetTypeInfo()) == Handlers_.end()); + Handlers_[handler->GetTypeInfo()] = handler; + } +} void DeadNodeElimination::run(rvsdg::Region & region) @@ -221,7 +227,7 @@ DeadNodeElimination::MarkOutput(const jlm::rvsdg::output & output) return; } - for (const auto handler : Handlers_) + for (const auto [_, handler] : Handlers_) { if (const auto continuations = handler->ComputeMarkPhaseContinuations(output)) { @@ -372,20 +378,18 @@ DeadNodeElimination::SweepRegion(rvsdg::Region & region) const void DeadNodeElimination::SweepStructuralNode(rvsdg::StructuralNode & node) const { - for (const auto handler : Handlers_) + if (const auto it = Handlers_.find(typeid(node)); it != Handlers_.end()) { - if (handler->GetTypeInfo() == typeid(node)) - { - handler->SweepNodeExit(node, Context_); - - for (size_t r = 0; r < node.nsubregions(); r++) - { - SweepRegion(*node.subregion(r)); - } + const auto handler = it->second; + handler->SweepNodeExit(node, Context_); - handler->SweepNodeEntry(node, Context_); - return; + for (size_t r = 0; r < node.nsubregions(); r++) + { + SweepRegion(*node.subregion(r)); } + + handler->SweepNodeEntry(node, Context_); + return; } auto sweepTheta = [](auto & d, auto & n) diff --git a/jlm/llvm/opt/DeadNodeElimination.hpp b/jlm/llvm/opt/DeadNodeElimination.hpp index 74714c19b..adb2f371e 100644 --- a/jlm/llvm/opt/DeadNodeElimination.hpp +++ b/jlm/llvm/opt/DeadNodeElimination.hpp @@ -177,7 +177,7 @@ class DeadNodeElimination final : public rvsdg::Transformation public: ~DeadNodeElimination() noexcept override; - explicit DeadNodeElimination(std::vector handlers); + explicit DeadNodeElimination(const std::vector & handlers); DeadNodeElimination(const DeadNodeElimination &) = delete; @@ -224,7 +224,7 @@ class DeadNodeElimination final : public rvsdg::Transformation SweepDelta(delta::node & deltaNode); DNEContext Context_; - std::vector Handlers_; + std::unordered_map Handlers_; }; }