From ee704ddc8dd241da864dc8d8ba5476f1415fa588 Mon Sep 17 00:00:00 2001 From: Helge Bahmann Date: Sat, 3 May 2025 01:47:31 +0200 Subject: [PATCH] Separate SimpleOperation / StructuralOperation Separate the inheritance hierarchies of SimpleOperation / StructuralOperation. This makes "SimpleOperation" the root class for all operational nodes. The inheritance hierarchy for StructuralOperation is kept for the moment, eventually the "contents" of structural operations can be entirely disassociated. They should evolve into something that can carry information specific to the structure, particularly backend-specific customizations of structural nodes. --- .../rhls2firrtl/RhlsToFirrtlConverter.cpp | 12 +- .../rhls2firrtl/RhlsToFirrtlConverter.hpp | 4 +- jlm/hls/backend/rhls2firrtl/dot-hls.cpp | 14 +- jlm/hls/backend/rvsdg2rhls/add-prints.cpp | 25 +-- jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp | 164 +++++++++--------- jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp | 156 +++++++++-------- jlm/hls/backend/rvsdg2rhls/mem-queue.cpp | 2 +- .../rvsdg2rhls/remove-redundant-buf.cpp | 4 +- jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp | 51 +++--- jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp | 136 ++++++++------- jlm/hls/ir/hls.cpp | 2 +- jlm/hls/ir/hls.hpp | 2 +- jlm/hls/opt/IOBarrierRemoval.cpp | 11 +- jlm/hls/opt/cne.cpp | 31 ++-- jlm/llvm/backend/RvsdgToIpGraphConverter.cpp | 2 +- jlm/llvm/ir/operators/operators.cpp | 3 +- jlm/llvm/ir/operators/sext.cpp | 3 +- jlm/llvm/opt/InvariantValueRedirection.cpp | 21 ++- .../TopDownModRefEliminator.cpp | 15 +- jlm/llvm/opt/cne.cpp | 29 ++-- jlm/llvm/opt/push.cpp | 2 +- jlm/llvm/opt/reduction.cpp | 12 +- jlm/llvm/opt/reduction.hpp | 9 +- jlm/llvm/opt/unroll.hpp | 23 ++- jlm/rvsdg/NodeNormalization.hpp | 2 +- jlm/rvsdg/binary.cpp | 30 +++- jlm/rvsdg/bitstring/bitoperation-classes.cpp | 12 +- jlm/rvsdg/bitstring/concat.cpp | 18 +- jlm/rvsdg/bitstring/slice.cpp | 12 +- jlm/rvsdg/control.cpp | 3 +- jlm/rvsdg/node.hpp | 13 -- jlm/rvsdg/simple-node.cpp | 6 +- jlm/rvsdg/simple-node.hpp | 13 +- jlm/rvsdg/structural-node.hpp | 3 + ...InvariantLambdaMemoryStateRemovalTests.cpp | 9 +- tests/jlm/llvm/ir/operators/LoadTests.cpp | 9 +- tests/jlm/llvm/ir/operators/StoreTests.cpp | 12 +- tests/jlm/llvm/ir/operators/TestCall.cpp | 6 +- tests/jlm/llvm/ir/operators/test-sext.cpp | 8 +- tests/jlm/llvm/opt/IfConversionTests.cpp | 7 +- .../TestIntegerOperationsJlmToMlirToJlm.cpp | 8 +- tests/jlm/mlir/TestJlmToMlirToJlm.cpp | 105 ++++++----- .../mlir/frontend/TestMlirToJlmConverter.cpp | 14 +- tests/jlm/rvsdg/SimpleOperationTests.cpp | 40 +++-- tests/jlm/rvsdg/bitstring/bitstring.cpp | 154 ++++++++-------- tests/jlm/rvsdg/test-cse.cpp | 14 +- tests/jlm/rvsdg/test-gamma.cpp | 2 +- 47 files changed, 681 insertions(+), 552 deletions(-) diff --git a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp index 5fffd499e..af54ef179 100644 --- a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp +++ b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp @@ -1153,10 +1153,10 @@ RhlsToFirrtlConverter::MlirGenHlsLocalMem(const jlm::rvsdg::SimpleNode * node) { auto lmem_op = dynamic_cast(&(node->GetOperation())); JLM_ASSERT(lmem_op); - auto res_node = rvsdg::TryGetOwnerNode(**node->output(0)->begin()); + auto res_node = rvsdg::TryGetOwnerNode(**node->output(0)->begin()); auto res_op = dynamic_cast(&res_node->GetOperation()); JLM_ASSERT(res_op); - auto req_node = rvsdg::TryGetOwnerNode(**node->output(1)->begin()); + auto req_node = rvsdg::TryGetOwnerNode(**node->output(1)->begin()); auto req_op = dynamic_cast(&req_node->GetOperation()); JLM_ASSERT(req_op); // Create the module and its input/output ports - we use a non-standard way here @@ -2788,8 +2788,8 @@ RhlsToFirrtlConverter::createInstances( { if (auto sn = dynamic_cast(node)) { - if (dynamic_cast(&(node->GetOperation())) - || dynamic_cast(&(node->GetOperation()))) + if (dynamic_cast(&(sn->GetOperation())) + || dynamic_cast(&(sn->GetOperation()))) { // these are virtual - connections go to local_mem instead continue; @@ -3993,7 +3993,7 @@ RhlsToFirrtlConverter::GetFirrtlType(const jlm::rvsdg::Type * type) } std::string -RhlsToFirrtlConverter::GetModuleName(const rvsdg::Node * node) +RhlsToFirrtlConverter::GetModuleName(const rvsdg::SimpleNode * node) { std::string append = ""; @@ -4085,7 +4085,7 @@ RhlsToFirrtlConverter::IsIdentityMapping(const jlm::rvsdg::match_op & op) void RhlsToFirrtlConverter::WriteModuleToFile( const circt::firrtl::FModuleOp fModuleOp, - const rvsdg::Node * node) + const rvsdg::SimpleNode * node) { if (!fModuleOp) return; diff --git a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp index ac61d9b41..662e61422 100644 --- a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp +++ b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp @@ -68,7 +68,7 @@ class RhlsToFirrtlConverter : public BaseHLS MlirGen(const rvsdg::LambdaNode * lamdaNode); void - WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::Node * node); + WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::SimpleNode * node); void WriteCircuitToFile(const circt::firrtl::CircuitOp circuit, std::string name); @@ -283,7 +283,7 @@ class RhlsToFirrtlConverter : public BaseHLS circt::firrtl::FIRRTLBaseType GetFirrtlType(const jlm::rvsdg::Type * type); std::string - GetModuleName(const rvsdg::Node * node); + GetModuleName(const rvsdg::SimpleNode * node); bool IsIdentityMapping(const jlm::rvsdg::match_op & op); diff --git a/jlm/hls/backend/rhls2firrtl/dot-hls.cpp b/jlm/hls/backend/rhls2firrtl/dot-hls.cpp index a770455bc..62b9ad727 100644 --- a/jlm/hls/backend/rhls2firrtl/dot-hls.cpp +++ b/jlm/hls/backend/rhls2firrtl/dot-hls.cpp @@ -248,8 +248,10 @@ DotHLS::loop_to_dot(hls::loop_node * ln) dot << "{rank=same "; for (auto node : rvsdg::TopDownTraverser(sr)) { - auto mx = dynamic_cast(&node->GetOperation()); - auto lc = dynamic_cast(&node->GetOperation()); + auto simpleNode = dynamic_cast(node); + auto mx = dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); + auto lc = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if ((mx && !mx->discarding && mx->loop) || lc) { dot << get_node_name(node) << " "; @@ -260,7 +262,9 @@ DotHLS::loop_to_dot(hls::loop_node * ln) dot << "{rank=same "; for (auto node : rvsdg::TopDownTraverser(sr)) { - auto br = dynamic_cast(&node->GetOperation()); + auto simpleNode = dynamic_cast(node); + auto br = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (br && br->loop) { dot << get_node_name(node) << " "; @@ -272,9 +276,9 @@ DotHLS::loop_to_dot(hls::loop_node * ln) // do edges outside in order not to pull other nodes into the cluster for (auto node : rvsdg::TopDownTraverser(sr)) { - if (dynamic_cast(node)) + if (auto simpleNode = dynamic_cast(node)) { - auto mx = dynamic_cast(&node->GetOperation()); + auto mx = dynamic_cast(&simpleNode->GetOperation()); auto node_name = get_node_name(node); for (size_t i = 0; i < node->ninputs(); ++i) { diff --git a/jlm/hls/backend/rvsdg2rhls/add-prints.cpp b/jlm/hls/backend/rvsdg2rhls/add-prints.cpp index ebf7a4342..2b82901d4 100644 --- a/jlm/hls/backend/rvsdg2rhls/add-prints.cpp +++ b/jlm/hls/backend/rvsdg2rhls/add-prints.cpp @@ -115,20 +115,23 @@ convert_prints( convert_prints(structnode->subregion(n), printf, functionType); } } - else if (auto po = dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto printf_local = route_to_region_rvsdg(printf, region); // TODO: prevent repetition? - auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id()); - jlm::rvsdg::output * val = node->input(0)->origin(); - if (*val->Type() != *jlm::rvsdg::bittype::Create(64)) + if (auto po = dynamic_cast(&(simpleNode->GetOperation()))) { - auto bt = std::dynamic_pointer_cast(val->Type()); - JLM_ASSERT(bt); - val = &llvm::ZExtOperation::Create(*val, rvsdg::bittype::Create(64)); + auto printf_local = route_to_region_rvsdg(printf, region); // TODO: prevent repetition? + auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id()); + jlm::rvsdg::output * val = node->input(0)->origin(); + if (*val->Type() != *jlm::rvsdg::bittype::Create(64)) + { + auto bt = std::dynamic_pointer_cast(val->Type()); + JLM_ASSERT(bt); + val = &llvm::ZExtOperation::Create(*val, rvsdg::bittype::Create(64)); + } + llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val }); + node->output(0)->divert_users(node->input(0)->origin()); + jlm::rvsdg::remove(node); } - llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val }); - node->output(0)->divert_users(node->input(0)->origin()); - jlm::rvsdg::remove(node); } } } diff --git a/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp b/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp index e26543080..1579374b7 100644 --- a/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp +++ b/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp @@ -127,99 +127,101 @@ alloca_conv(rvsdg::Region * region) alloca_conv(structnode->subregion(n)); } } - else if (auto po = dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - // ensure that the size is one - JLM_ASSERT(node->ninputs() == 1); - auto constant_output = dynamic_cast(node->input(0)->origin()); - JLM_ASSERT(constant_output); - auto constant_operation = dynamic_cast( - &constant_output->node()->GetOperation()); - JLM_ASSERT(constant_operation); - JLM_ASSERT(constant_operation->Representation().to_uint() == 1); - // ensure that the alloca is an array type - auto at = std::dynamic_pointer_cast(po->ValueType()); - JLM_ASSERT(at); - // detect loads and stores attached to alloca - TraceAllocaUses ta(node->output(0)); - // create memory + response - auto mem_outs = local_mem_op::create(at, node->region()); - auto resp_outs = local_mem_resp_op::create(*mem_outs[0], ta.load_nodes.size()); - std::cout << "alloca converted " << at->debug_string() << std::endl; - // replace gep outputs (convert pointer to index calculation) - // replace loads and stores - std::vector load_addrs; - for (auto l : ta.load_nodes) + if (auto po = dynamic_cast(&(simpleNode->GetOperation()))) { - auto index = gep_to_index(l->input(0)->origin()); - auto response = route_response_rhls(l->region(), resp_outs.front()); - resp_outs.erase(resp_outs.begin()); - std::vector states; - for (size_t i = 1; i < l->ninputs(); ++i) + // ensure that the size is one + JLM_ASSERT(node->ninputs() == 1); + auto & constant_node = + rvsdg::AssertGetOwnerNode(*node->input(0)->origin()); + auto constant_operation = + util::AssertedCast(&constant_node.GetOperation()); + JLM_ASSERT(constant_operation->Representation().to_uint() == 1); + // ensure that the alloca is an array type + auto at = std::dynamic_pointer_cast(po->ValueType()); + JLM_ASSERT(at); + // detect loads and stores attached to alloca + TraceAllocaUses ta(node->output(0)); + // create memory + response + auto mem_outs = local_mem_op::create(at, node->region()); + auto resp_outs = local_mem_resp_op::create(*mem_outs[0], ta.load_nodes.size()); + std::cout << "alloca converted " << at->debug_string() << std::endl; + // replace gep outputs (convert pointer to index calculation) + // replace loads and stores + std::vector load_addrs; + for (auto l : ta.load_nodes) { - states.push_back(l->input(i)->origin()); + auto index = gep_to_index(l->input(0)->origin()); + auto response = route_response_rhls(l->region(), resp_outs.front()); + resp_outs.erase(resp_outs.begin()); + std::vector states; + for (size_t i = 1; i < l->ninputs(); ++i) + { + states.push_back(l->input(i)->origin()); + } + auto load_outs = local_load_op::create(*index, states, *response); + auto nn = dynamic_cast(load_outs[0])->node(); + for (size_t i = 0; i < l->noutputs(); ++i) + { + l->output(i)->divert_users(nn->output(i)); + } + remove(l); + auto addr = route_request_rhls(node->region(), load_outs.back()); + load_addrs.push_back(addr); } - auto load_outs = local_load_op::create(*index, states, *response); - auto nn = dynamic_cast(load_outs[0])->node(); - for (size_t i = 0; i < l->noutputs(); ++i) + std::vector store_operands; + for (auto s : ta.store_nodes) { - l->output(i)->divert_users(nn->output(i)); + auto index = gep_to_index(s->input(0)->origin()); + std::vector states; + for (size_t i = 2; i < s->ninputs(); ++i) + { + states.push_back(s->input(i)->origin()); + } + auto store_outs = local_store_op::create(*index, *s->input(1)->origin(), states); + auto nn = dynamic_cast(store_outs[0])->node(); + for (size_t i = 0; i < s->noutputs(); ++i) + { + s->output(i)->divert_users(nn->output(i)); + } + remove(s); + auto addr = route_request_rhls(node->region(), store_outs[store_outs.size() - 2]); + auto data = route_request_rhls(node->region(), store_outs.back()); + store_operands.push_back(addr); + store_operands.push_back(data); } - remove(l); - auto addr = route_request_rhls(node->region(), load_outs.back()); - load_addrs.push_back(addr); - } - std::vector store_operands; - for (auto s : ta.store_nodes) - { - auto index = gep_to_index(s->input(0)->origin()); - std::vector states; - for (size_t i = 2; i < s->ninputs(); ++i) + // TODO: ensure that loads/stores are either alloca or global, never both + // TODO: ensure that loads/stores have same width and alignment and geps can be merged - + // otherwise slice? create request + auto req_outs = local_mem_req_op::create(*mem_outs[1], load_addrs, store_operands); + + // remove alloca from memstate merge + // TODO: handle general case of other nodes getting state edge without a merge + JLM_ASSERT(node->output(1)->nusers() == 1); + auto merge_in = *node->output(1)->begin(); + auto merge_node = rvsdg::TryGetOwnerNode(*merge_in); + if (dynamic_cast(&merge_node->GetOperation())) { - states.push_back(s->input(i)->origin()); + // merge after alloca -> remove merge + JLM_ASSERT(merge_node->ninputs() == 2); + auto other_index = merge_in->index() ? 0 : 1; + merge_node->output(0)->divert_users(merge_node->input(other_index)->origin()); + jlm::rvsdg::remove(merge_node); } - auto store_outs = local_store_op::create(*index, *s->input(1)->origin(), states); - auto nn = dynamic_cast(store_outs[0])->node(); - for (size_t i = 0; i < s->noutputs(); ++i) + else { - s->output(i)->divert_users(nn->output(i)); + // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing + // it to the region + JLM_ASSERT(false); } - remove(s); - auto addr = route_request_rhls(node->region(), store_outs[store_outs.size() - 2]); - auto data = route_request_rhls(node->region(), store_outs.back()); - store_operands.push_back(addr); - store_operands.push_back(data); - } - // TODO: ensure that loads/stores are either alloca or global, never both - // TODO: ensure that loads/stores have same width and alignment and geps can be merged - - // otherwise slice? create request - auto req_outs = local_mem_req_op::create(*mem_outs[1], load_addrs, store_operands); - // remove alloca from memstate merge - // TODO: handle general case of other nodes getting state edge without a merge - JLM_ASSERT(node->output(1)->nusers() == 1); - auto merge_in = *node->output(1)->begin(); - auto merge_node = rvsdg::TryGetOwnerNode(*merge_in); - if (dynamic_cast(&merge_node->GetOperation())) - { - // merge after alloca -> remove merge - JLM_ASSERT(merge_node->ninputs() == 2); - auto other_index = merge_in->index() ? 0 : 1; - merge_node->output(0)->divert_users(merge_node->input(other_index)->origin()); - jlm::rvsdg::remove(merge_node); + // TODO: run dne to + // remove loads/stores + // remove geps + // remove alloca pointer users + // remove alloca } - else - { - // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing it - // to the region - JLM_ASSERT(false); - } - - // TODO: run dne to - // remove loads/stores - // remove geps - // remove alloca pointer users - // remove alloca } } } diff --git a/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp b/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp index 0c103df02..252c1be63 100644 --- a/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp +++ b/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp @@ -163,89 +163,93 @@ instrument_ref( allocaFunctionType); } } - else if ( - auto loadOp = - dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto addr = node->input(0)->origin(); - JLM_ASSERT(rvsdg::is(addr->Type())); - size_t bitWidth = BaseHLS::JlmSize(&*loadOp->GetLoadedType()); - int log2Bytes = log2(bitWidth / 8); - auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); - - // Does this IF make sense now when the void_ptr doesn't have a type? - if (*addr->Type() != *void_ptr) - { - addr = jlm::llvm::bitcast_op::create(addr, void_ptr); - } - auto memstate = node->input(1)->origin(); - auto callOp = jlm::llvm::CallOperation::Create( - load_func, - loadFunctionType, - { addr, widthNode.output(0), ioState, memstate }); - // Divert the memory state of the load to the new memstate from the call operation - node->input(1)->divert_to(callOp[1]); - } - else if (auto ao = dynamic_cast(&(node->GetOperation()))) - { - // ensure that the size is one - JLM_ASSERT(node->ninputs() == 1); - auto constant_output = dynamic_cast(node->input(0)->origin()); - JLM_ASSERT(constant_output); - auto constant_operation = dynamic_cast( - &constant_output->node()->GetOperation()); - JLM_ASSERT(constant_operation); - JLM_ASSERT(constant_operation->Representation().to_uint() == 1); - jlm::rvsdg::output * addr = node->output(0); - // ensure that the alloca is an array type - JLM_ASSERT(jlm::rvsdg::is(addr->Type())); - auto at = dynamic_cast(&ao->value_type()); - JLM_ASSERT(at); - auto & sizeNode = - llvm::IntegerConstantOperation::Create(*region, 64, BaseHLS::JlmSize(at) / 8); - - // Does this IF make sense now when the void_ptr doesn't have a type? - if (*addr->Type() != *void_ptr) - { - addr = jlm::llvm::bitcast_op::create(addr, void_ptr); - } - std::vector old_users(node->output(1)->begin(), node->output(1)->end()); - auto memstate = node->output(1); - auto callOp = jlm::llvm::CallOperation::Create( - alloca_func, - allocaFunctionType, - { addr, sizeNode.output(0), ioState, memstate }); - for (auto ou : old_users) + if (auto loadOp = dynamic_cast( + &(simpleNode->GetOperation()))) { + auto addr = node->input(0)->origin(); + JLM_ASSERT(std::dynamic_pointer_cast(addr->Type())); + size_t bitWidth = BaseHLS::JlmSize(&*loadOp->GetLoadedType()); + int log2Bytes = log2(bitWidth / 8); + auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); + + // Does this IF make sense now when the void_ptr doesn't have a type? + if (*addr->Type() != *void_ptr) + { + addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + } + auto memstate = node->input(1)->origin(); + auto callOp = jlm::llvm::CallOperation::Create( + load_func, + loadFunctionType, + { addr, widthNode.output(0), ioState, memstate }); // Divert the memory state of the load to the new memstate from the call operation - ou->divert_to(callOp[1]); + node->input(1)->divert_to(callOp[1]); } - } - else if ( - auto so = - dynamic_cast(&(node->GetOperation()))) - { - auto addr = node->input(0)->origin(); - JLM_ASSERT(rvsdg::is(addr->Type())); - auto bitWidth = JlmSize(&so->GetStoredType()); - int log2Bytes = log2(bitWidth / 8); - auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); - - // Does this IF make sense now when the void_ptr doesn't have a type? - if (*addr->Type() != *void_ptr) + else if (auto ao = dynamic_cast(&(simpleNode->GetOperation()))) { - addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + // ensure that the size is one + JLM_ASSERT(node->ninputs() == 1); + auto & constant_node = + rvsdg::AssertGetOwnerNode(*node->input(0)->origin()); + auto constant_operation = + util::AssertedCast(&constant_node.GetOperation()); + JLM_ASSERT(constant_operation->Representation().to_uint() == 1); + jlm::rvsdg::output * addr = node->output(0); + // ensure that the alloca is an array type + auto pt = std::dynamic_pointer_cast(addr->Type()); + JLM_ASSERT(pt); + auto at = dynamic_cast(&ao->value_type()); + JLM_ASSERT(at); + auto & sizeNode = + llvm::IntegerConstantOperation::Create(*region, 64, BaseHLS::JlmSize(at) / 8); + + // Does this IF make sense now when the void_ptr doesn't have a type? + if (*addr->Type() != *void_ptr) + { + addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + } + std::vector old_users( + node->output(1)->begin(), + node->output(1)->end()); + auto memstate = node->output(1); + auto callOp = jlm::llvm::CallOperation::Create( + alloca_func, + allocaFunctionType, + { addr, sizeNode.output(0), ioState, memstate }); + for (auto ou : old_users) + { + // Divert the memory state of the load to the new memstate from the call operation + ou->divert_to(callOp[1]); + } } - auto memstate = node->output(0); - std::vector oldUsers(memstate->begin(), memstate->end()); - auto callOp = jlm::llvm::CallOperation::Create( - store_func, - storeFunctionType, - { addr, widthNode.output(0), ioState, memstate }); - // Divert the memory state after the store to the new memstate from the call operation - for (auto user : oldUsers) + else if ( + auto so = dynamic_cast( + &(simpleNode->GetOperation()))) { - user->divert_to(callOp[1]); + auto addr = node->input(0)->origin(); + JLM_ASSERT(std::dynamic_pointer_cast(addr->Type())); + auto bitWidth = JlmSize(&so->GetStoredType()); + int log2Bytes = log2(bitWidth / 8); + auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); + + // Does this IF make sense now when the void_ptr doesn't have a type? + if (*addr->Type() != *void_ptr) + { + addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + } + auto memstate = node->output(0); + std::vector oldUsers(memstate->begin(), memstate->end()); + auto callOp = jlm::llvm::CallOperation::Create( + store_func, + storeFunctionType, + { addr, widthNode.output(0), ioState, memstate }); + // Divert the memory state after the store to the new memstate from the call operation + for (auto user : oldUsers) + { + user->divert_to(callOp[1]); + } } } } diff --git a/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp b/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp index df624545e..119b45b8a 100644 --- a/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp +++ b/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp @@ -507,7 +507,7 @@ jlm::hls::mem_queue(jlm::rvsdg::Region * region) // Check if there exists a memory state splitter if (state_arg->nusers() == 1) { - auto entryNode = rvsdg::TryGetOwnerNode(**state_arg->begin()); + auto entryNode = rvsdg::TryGetOwnerNode(**state_arg->begin()); if (jlm::rvsdg::is( entryNode->GetOperation())) { diff --git a/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp b/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp index 883463026..5918e6fb7 100644 --- a/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp +++ b/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp @@ -58,9 +58,9 @@ remove_redundant_buf(rvsdg::Region * region) remove_redundant_buf(structnode->subregion(n)); } } - else if (dynamic_cast(node)) + else if (auto simplenode = dynamic_cast(node)) { - if (auto buf = dynamic_cast(&node->GetOperation())) + if (auto buf = dynamic_cast(&simplenode->GetOperation())) { if (std::dynamic_pointer_cast(buf->argument(0))) { diff --git a/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp b/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp index de4776c32..3e57cb997 100644 --- a/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp +++ b/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp @@ -126,7 +126,7 @@ remove_unused_loop_inputs(loop_node * ln) } bool -dead_spec_gamma(rvsdg::Node * dmux_node) +dead_spec_gamma(rvsdg::SimpleNode * dmux_node) { auto mux_op = dynamic_cast(&dmux_node->GetOperation()); JLM_ASSERT(mux_op); @@ -152,7 +152,7 @@ dead_spec_gamma(rvsdg::Node * dmux_node) } bool -dead_nonspec_gamma(rvsdg::Node * ndmux_node) +dead_nonspec_gamma(rvsdg::SimpleNode * ndmux_node) { auto mux_op = dynamic_cast(&ndmux_node->GetOperation()); JLM_ASSERT(mux_op); @@ -194,7 +194,7 @@ dead_nonspec_gamma(rvsdg::Node * ndmux_node) } bool -dead_loop(rvsdg::Node * ndmux_node) +dead_loop(rvsdg::SimpleNode * ndmux_node) { auto mux_op = dynamic_cast(&ndmux_node->GetOperation()); JLM_ASSERT(mux_op); @@ -282,19 +282,23 @@ dne(rvsdg::Region * sr) { if (!node->has_users()) { - if (dynamic_cast(&node->GetOperation())) + if (auto simpleNode = dynamic_cast(node)) { - // TODO: fix this once memory connections are explicit - continue; - } - else if (dynamic_cast(&node->GetOperation())) - { - continue; - } - else if (dynamic_cast(&node->GetOperation())) - { - // TODO: fix - this scenario has only stores and should just be optimized away completely - continue; + if (dynamic_cast(&simpleNode->GetOperation())) + { + // TODO: fix this once memory connections are explicit + continue; + } + else if (dynamic_cast(&simpleNode->GetOperation())) + { + continue; + } + else if (dynamic_cast(&simpleNode->GetOperation())) + { + // TODO: fix - this scenario has only stores and should just be optimized away + // completely + continue; + } } remove(node); changed = true; @@ -307,15 +311,18 @@ dne(rvsdg::Region * sr) changed |= remove_loop_passthrough(ln); changed |= dne(ln->subregion()); } - else if (auto mux = dynamic_cast(&node->GetOperation())) + else if (auto simpleNode = dynamic_cast(node)) { - if (mux->discarding) - { - changed |= dead_spec_gamma(node); - } - else + if (auto mux = dynamic_cast(&simpleNode->GetOperation())) { - changed |= dead_nonspec_gamma(node) || dead_loop(node); + if (mux->discarding) + { + changed |= dead_spec_gamma(simpleNode); + } + else + { + changed |= dead_nonspec_gamma(simpleNode) || dead_loop(simpleNode); + } } } } diff --git a/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp b/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp index f19f4c489..e7961509f 100644 --- a/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp +++ b/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp @@ -150,30 +150,33 @@ inline_calls(rvsdg::Region * region) inline_calls(structnode->subregion(n)); } } - else if (dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto traced = jlm::hls::trace_call(node->input(0)); - auto so = dynamic_cast(traced); - if (!so) + if (dynamic_cast(&(simpleNode->GetOperation()))) { - if (auto graphImport = dynamic_cast(traced)) + auto traced = jlm::hls::trace_call(node->input(0)); + auto so = dynamic_cast(traced); + if (!so) { - if (graphImport->Name().rfind("decouple_", 0) == 0) + if (auto graphImport = dynamic_cast(traced)) { - // can't inline pseudo functions used for decoupling - continue; + if (graphImport->Name().rfind("decouple_", 0) == 0) + { + // can't inline pseudo functions used for decoupling + continue; + } + throw jlm::util::error("can not inline external function " + graphImport->Name()); } - throw jlm::util::error("can not inline external function " + graphImport->Name()); } + JLM_ASSERT(dynamic_cast(so->node())); + auto ln = dynamic_cast(traced)->node(); + llvm::inlineCall( + dynamic_cast(node), + dynamic_cast(ln)); + // restart for this region + inline_calls(region); + return; } - JLM_ASSERT(rvsdg::is(so->node())); - auto ln = dynamic_cast(traced)->node(); - llvm::inlineCall( - dynamic_cast(node), - dynamic_cast(ln)); - // restart for this region - inline_calls(region); - return; } } } @@ -192,57 +195,60 @@ convert_alloca(rvsdg::Region * region) convert_alloca(structnode->subregion(n)); } } - else if (auto po = dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto rr = ®ion->graph()->GetRootRegion(); - auto delta_name = jlm::util::strfmt("hls_alloca_", alloca_cnt++); - auto delta_type = llvm::PointerType::Create(); - std::cout << "alloca " << delta_name << ": " << po->value_type().debug_string() << "\n"; - auto db = llvm::delta::node::Create( - rr, - std::static_pointer_cast(po->ValueType()), - delta_name, - llvm::linkage::external_linkage, - "", - false); - // create zero constant of allocated type - jlm::rvsdg::output * cout; - if (auto bt = dynamic_cast(&po->value_type())) + if (auto po = dynamic_cast(&(simpleNode->GetOperation()))) { - cout = llvm::IntegerConstantOperation::Create( - *db->subregion(), - bt->Representation().nbits(), - 0) - .output(0); - } - else - { - cout = llvm::ConstantAggregateZeroOperation::Create(*db->subregion(), po->ValueType()); - } - auto delta = db->finalize(cout); - jlm::llvm::GraphExport::Create(*delta, delta_name); - auto delta_local = route_to_region_rvsdg(delta, region); - node->output(0)->divert_users(delta_local); - // TODO: check that the input to alloca is a bitconst 1 - // TODO: handle general case of other nodes getting state edge without a merge - JLM_ASSERT(node->output(1)->nusers() == 1); - auto mux_in = *node->output(1)->begin(); - auto mux_node = rvsdg::TryGetOwnerNode(*mux_in); - if (dynamic_cast(&mux_node->GetOperation())) - { - // merge after alloca -> remove merge - JLM_ASSERT(mux_node->ninputs() == 2); - auto other_index = mux_in->index() ? 0 : 1; - mux_node->output(0)->divert_users(mux_node->input(other_index)->origin()); - jlm::rvsdg::remove(mux_node); - } - else - { - // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing it - // to the region - JLM_ASSERT(false); + auto rr = ®ion->graph()->GetRootRegion(); + auto delta_name = jlm::util::strfmt("hls_alloca_", alloca_cnt++); + auto delta_type = llvm::PointerType::Create(); + std::cout << "alloca " << delta_name << ": " << po->value_type().debug_string() << "\n"; + auto db = llvm::delta::node::Create( + rr, + std::static_pointer_cast(po->ValueType()), + delta_name, + llvm::linkage::external_linkage, + "", + false); + // create zero constant of allocated type + jlm::rvsdg::output * cout; + if (auto bt = dynamic_cast(&po->value_type())) + { + cout = llvm::IntegerConstantOperation::Create( + *db->subregion(), + bt->Representation().nbits(), + 0) + .output(0); + } + else + { + cout = llvm::ConstantAggregateZeroOperation::Create(*db->subregion(), po->ValueType()); + } + auto delta = db->finalize(cout); + jlm::llvm::GraphExport::Create(*delta, delta_name); + auto delta_local = route_to_region_rvsdg(delta, region); + node->output(0)->divert_users(delta_local); + // TODO: check that the input to alloca is a bitconst 1 + // TODO: handle general case of other nodes getting state edge without a merge + JLM_ASSERT(node->output(1)->nusers() == 1); + auto mux_in = *node->output(1)->begin(); + auto mux_node = rvsdg::TryGetOwnerNode(*mux_in); + if (dynamic_cast(&mux_node->GetOperation())) + { + // merge after alloca -> remove merge + JLM_ASSERT(mux_node->ninputs() == 2); + auto other_index = mux_in->index() ? 0 : 1; + mux_node->output(0)->divert_users(mux_node->input(other_index)->origin()); + jlm::rvsdg::remove(mux_node); + } + else + { + // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing + // it to the region + JLM_ASSERT(false); + } + jlm::rvsdg::remove(node); } - jlm::rvsdg::remove(node); } } } diff --git a/jlm/hls/ir/hls.cpp b/jlm/hls/ir/hls.cpp index fb0b247b1..3310fc8ce 100644 --- a/jlm/hls/ir/hls.cpp +++ b/jlm/hls/ir/hls.cpp @@ -95,7 +95,7 @@ loop_node::AddLoopVar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer) return output; } -[[nodiscard]] const rvsdg::Operation & +[[nodiscard]] const rvsdg::StructuralOperation & loop_node::GetOperation() const noexcept { static const loop_op singleton; diff --git a/jlm/hls/ir/hls.hpp b/jlm/hls/ir/hls.hpp index 69db9cddd..a94f98e0c 100644 --- a/jlm/hls/ir/hls.hpp +++ b/jlm/hls/ir/hls.hpp @@ -755,7 +755,7 @@ class loop_node final : public rvsdg::StructuralNode jlm::rvsdg::node_output * _predicate_buffer; public: - [[nodiscard]] const rvsdg::Operation & + [[nodiscard]] const rvsdg::StructuralOperation & GetOperation() const noexcept override; static loop_node * diff --git a/jlm/hls/opt/IOBarrierRemoval.cpp b/jlm/hls/opt/IOBarrierRemoval.cpp index 41711d5af..a531d63ab 100644 --- a/jlm/hls/opt/IOBarrierRemoval.cpp +++ b/jlm/hls/opt/IOBarrierRemoval.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace jlm::hls @@ -33,11 +34,13 @@ IOBarrierRemoval::RemoveIOBarrierFromRegion(rvsdg::Region & region) RemoveIOBarrierFromRegion(*structuralNode->subregion(n)); } } - - // Render all IOBarrier nodes dead - if (rvsdg::is(&node)) + else if (auto simpleNode = dynamic_cast(&node)) { - node.output(0)->divert_users(node.input(0)->origin()); + // Render all IOBarrier nodes dead + if (rvsdg::is(simpleNode)) + { + node.output(0)->divert_users(node.input(0)->origin()); + } } } diff --git a/jlm/hls/opt/cne.cpp b/jlm/hls/opt/cne.cpp index 257055c0f..5c6c58764 100644 --- a/jlm/hls/opt/cne.cpp +++ b/jlm/hls/opt/cne.cpp @@ -272,20 +272,24 @@ congruent(jlm::rvsdg::output * o1, jlm::rvsdg::output * o2, vset & vs, cnectx & } } - if (jlm::rvsdg::is(n1) && jlm::rvsdg::is(n2) - && n1->GetOperation() == n2->GetOperation() && n1->ninputs() == n2->ninputs() - && o1->index() == o2->index()) + if (auto s1 = rvsdg::TryGetOwnerNode(*o1)) { - for (size_t n = 0; n < n1->ninputs(); n++) + if (auto s2 = rvsdg::TryGetOwnerNode(*o2)) { - auto origin1 = n1->input(n)->origin(); - auto origin2 = n2->input(n)->origin(); - if (!congruent(origin1, origin2, vs, ctx)) - return false; + if (s1->GetOperation() == s2->GetOperation() && s1->ninputs() == s2->ninputs() + && o1->index() == o2->index()) + { + for (size_t n = 0; n < s1->ninputs(); n++) + { + auto origin1 = s1->input(n)->origin(); + auto origin2 = s2->input(n)->origin(); + if (!congruent(origin1, origin2, vs, ctx)) + return false; + } + return true; + } } - return true; } - return false; } @@ -459,7 +463,9 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & other : node->region()->TopNodes()) { - if (&other != node && node->GetOperation() == other.GetOperation()) + auto otherSimpleNode = dynamic_cast(&other); + if (otherSimpleNode && &other != node + && node->GetOperation() == otherSimpleNode->GetOperation()) { ctx.mark(node, &other); break; @@ -473,8 +479,7 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & user : *origin) { - auto ni = dynamic_cast(user); - auto other = ni ? ni->node() : nullptr; + auto other = rvsdg::TryGetOwnerNode(*user); if (!other || other == node || other->GetOperation() != node->GetOperation() || other->ninputs() != node->ninputs()) continue; diff --git a/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp b/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp index a1c2e3b99..dd830c695 100644 --- a/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp +++ b/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp @@ -162,7 +162,7 @@ RvsdgToIpGraphConverter::CreateInitialization(const delta::node & deltaNode) operands.push_back(Context_->GetVariable(node->input(n)->origin())); // convert node to tac - auto & op = *static_cast(&node->GetOperation()); + auto & op = util::AssertedCast(node)->GetOperation(); tacs.push_back(tac::create(op, operands)); Context_->InsertVariable(output, tacs.back()->result(0)); } diff --git a/jlm/llvm/ir/operators/operators.cpp b/jlm/llvm/ir/operators/operators.cpp index 370c8221d..c4bf408a2 100644 --- a/jlm/llvm/ir/operators/operators.cpp +++ b/jlm/llvm/ir/operators/operators.cpp @@ -408,7 +408,8 @@ ZExtOperation::reduce_operand(rvsdg::unop_reduction_path_t path, rvsdg::output * { if (path == rvsdg::unop_reduction_constant) { - auto c = static_cast(&producer(operand)->GetOperation()); + auto c = util::AssertedCast( + &util::AssertedCast(producer(operand))->GetOperation()); return create_bitconstant( rvsdg::TryGetOwnerNode(*operand)->region(), c->value().zext(ndstbits() - nsrcbits())); diff --git a/jlm/llvm/ir/operators/sext.cpp b/jlm/llvm/ir/operators/sext.cpp index 52707f72c..c3cc76bdd 100644 --- a/jlm/llvm/ir/operators/sext.cpp +++ b/jlm/llvm/ir/operators/sext.cpp @@ -118,7 +118,8 @@ sext_op::reduce_operand(rvsdg::unop_reduction_path_t path, rvsdg::output * opera { if (path == rvsdg::unop_reduction_constant) { - auto c = static_cast(&producer(operand)->GetOperation()); + auto c = util::AssertedCast( + &util::AssertedCast(producer(operand))->GetOperation()); return create_bitconstant(operand->region(), c->value().sext(ndstbits() - nsrcbits())); } diff --git a/jlm/llvm/opt/InvariantValueRedirection.cpp b/jlm/llvm/opt/InvariantValueRedirection.cpp index bfd39ca99..a4aee740e 100644 --- a/jlm/llvm/opt/InvariantValueRedirection.cpp +++ b/jlm/llvm/opt/InvariantValueRedirection.cpp @@ -86,11 +86,17 @@ InvariantValueRedirection::RedirectInRootRegion(rvsdg::Graph & rvsdg) // Nothing needs to be done. // Delta nodes are irrelevant for invariant value redirection. } - else if ( - is(node->GetOperation()) - || is(node->GetOperation())) + else if (auto simpleNode = dynamic_cast(node)) { - // Nothing needs to be done. + if (is(simpleNode->GetOperation()) + || is(simpleNode->GetOperation())) + { + // Nothing needs to be done. + } + else + { + JLM_UNREACHABLE("Unhandled node type."); + } } else { @@ -125,9 +131,12 @@ InvariantValueRedirection::RedirectInRegion(rvsdg::Region & region) RedirectInSubregions(*thetaNode); RedirectThetaOutputs(*thetaNode); } - else if (is(&node)) + else if (auto simpleNode = dynamic_cast(&node)) { - RedirectCallOutputs(*util::AssertedCast(&node)); + if (is(simpleNode)) + { + RedirectCallOutputs(*util::AssertedCast(&node)); + } } } } diff --git a/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp b/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp index 957aae803..cef90f8fb 100644 --- a/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp +++ b/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp @@ -504,11 +504,18 @@ TopDownModRefEliminator::EliminateTopDownRootRegion(rvsdg::Region & region) { // Nothing needs to be done. } - else if ( - is(node->GetOperation()) - || is(node->GetOperation())) + else if (auto simpleNode = dynamic_cast(node)) { - // Nothing needs to be done. + if (is(simpleNode->GetOperation()) + || is(simpleNode->GetOperation())) + + { + // Nothing needs to be done. + } + else + { + JLM_UNREACHABLE("Unhandled node type!"); + } } else { diff --git a/jlm/llvm/opt/cne.cpp b/jlm/llvm/opt/cne.cpp index cf0ded29b..f6f5149c6 100644 --- a/jlm/llvm/opt/cne.cpp +++ b/jlm/llvm/opt/cne.cpp @@ -261,18 +261,23 @@ congruent(jlm::rvsdg::output * o1, jlm::rvsdg::output * o2, vset & vs, cnectx & } } - if (jlm::rvsdg::is(n1) && jlm::rvsdg::is(n2) - && n1->GetOperation() == n2->GetOperation() && n1->ninputs() == n2->ninputs() - && o1->index() == o2->index()) + if (auto simple1 = rvsdg::TryGetOwnerNode(*o1)) { - for (size_t n = 0; n < n1->ninputs(); n++) + if (auto simple2 = rvsdg::TryGetOwnerNode(*o2)) { - auto origin1 = n1->input(n)->origin(); - auto origin2 = n2->input(n)->origin(); - if (!congruent(origin1, origin2, vs, ctx)) - return false; + if (simple1->GetOperation() == simple2->GetOperation() + && simple1->ninputs() == simple2->ninputs() && o1->index() == o2->index()) + { + for (size_t n = 0; n < n1->ninputs(); n++) + { + auto origin1 = n1->input(n)->origin(); + auto origin2 = n2->input(n)->origin(); + if (!congruent(origin1, origin2, vs, ctx)) + return false; + } + return true; + } } - return true; } return false; @@ -423,7 +428,8 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & other : node->region()->TopNodes()) { - if (&other != node && node->GetOperation() == other.GetOperation()) + auto otherSimple = dynamic_cast(&other); + if (&other != node && otherSimple && node->GetOperation() == otherSimple->GetOperation()) { ctx.mark(node, &other); break; @@ -437,8 +443,7 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & user : *origin) { - auto ni = dynamic_cast(user); - auto other = ni ? ni->node() : nullptr; + auto other = rvsdg::TryGetOwnerNode(*user); if (!other || other == node || other->GetOperation() != node->GetOperation() || other->ninputs() != node->ninputs()) continue; diff --git a/jlm/llvm/opt/push.cpp b/jlm/llvm/opt/push.cpp index 9593ab7cc..b4eb4b705 100644 --- a/jlm/llvm/opt/push.cpp +++ b/jlm/llvm/opt/push.cpp @@ -322,7 +322,7 @@ is_movable_store(rvsdg::Node * node) } static void -pushout_store(rvsdg::Node * storenode) +pushout_store(rvsdg::SimpleNode * storenode) { JLM_ASSERT(dynamic_cast(storenode->region()->node())); JLM_ASSERT(jlm::rvsdg::is(storenode) && is_movable_store(storenode)); diff --git a/jlm/llvm/opt/reduction.cpp b/jlm/llvm/opt/reduction.cpp index 8c43b98c4..b6f202a83 100644 --- a/jlm/llvm/opt/reduction.cpp +++ b/jlm/llvm/opt/reduction.cpp @@ -84,9 +84,9 @@ NodeReduction::ReduceNodesInRegion(rvsdg::Region & region) { reductionPerformed |= ReduceStructuralNode(*structuralNode); } - else if (rvsdg::is(node)) + else if (const auto simpleNode = dynamic_cast(node)) { - reductionPerformed |= ReduceSimpleNode(*node); + reductionPerformed |= ReduceSimpleNode(*simpleNode); } else { @@ -144,7 +144,7 @@ NodeReduction::ReduceGammaNode(rvsdg::StructuralNode & gammaNode) } bool -NodeReduction::ReduceSimpleNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceSimpleNode(rvsdg::SimpleNode & simpleNode) { if (is(&simpleNode)) { @@ -169,7 +169,7 @@ NodeReduction::ReduceSimpleNode(rvsdg::Node & simpleNode) } bool -NodeReduction::ReduceLoadNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceLoadNode(rvsdg::SimpleNode & simpleNode) { JLM_ASSERT(is(&simpleNode)); @@ -177,7 +177,7 @@ NodeReduction::ReduceLoadNode(rvsdg::Node & simpleNode) } bool -NodeReduction::ReduceStoreNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceStoreNode(rvsdg::SimpleNode & simpleNode) { JLM_ASSERT(is(&simpleNode)); @@ -185,7 +185,7 @@ NodeReduction::ReduceStoreNode(rvsdg::Node & simpleNode) } bool -NodeReduction::ReduceBinaryNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceBinaryNode(rvsdg::SimpleNode & simpleNode) { JLM_ASSERT(is(&simpleNode)); diff --git a/jlm/llvm/opt/reduction.hpp b/jlm/llvm/opt/reduction.hpp index 4a7e53be8..e79e142a1 100644 --- a/jlm/llvm/opt/reduction.hpp +++ b/jlm/llvm/opt/reduction.hpp @@ -6,6 +6,7 @@ #ifndef JLM_LLVM_OPT_REDUCTION_HPP #define JLM_LLVM_OPT_REDUCTION_HPP +#include #include #include @@ -73,16 +74,16 @@ class NodeReduction final : public rvsdg::Transformation ReduceGammaNode(rvsdg::StructuralNode & gammaNode); [[nodiscard]] static bool - ReduceSimpleNode(rvsdg::Node & simpleNode); + ReduceSimpleNode(rvsdg::SimpleNode & simpleNode); [[nodiscard]] static bool - ReduceLoadNode(rvsdg::Node & simpleNode); + ReduceLoadNode(rvsdg::SimpleNode & simpleNode); [[nodiscard]] static bool - ReduceStoreNode(rvsdg::Node & simpleNode); + ReduceStoreNode(rvsdg::SimpleNode & simpleNode); [[nodiscard]] static bool - ReduceBinaryNode(rvsdg::Node & simpleNode); + ReduceBinaryNode(rvsdg::SimpleNode & simpleNode); static std::optional> NormalizeLoadNode( diff --git a/jlm/llvm/opt/unroll.hpp b/jlm/llvm/opt/unroll.hpp index 7aab200b1..33bdf3701 100644 --- a/jlm/llvm/opt/unroll.hpp +++ b/jlm/llvm/opt/unroll.hpp @@ -51,8 +51,8 @@ class unrollinfo final private: inline unrollinfo( - rvsdg::Node * cmpnode, - rvsdg::Node * armnode, + rvsdg::SimpleNode * cmpnode, + rvsdg::SimpleNode * armnode, rvsdg::output * idv, rvsdg::output * step, rvsdg::output * end) @@ -108,7 +108,7 @@ class unrollinfo final std::unique_ptr niterations() const noexcept; - rvsdg::Node * + rvsdg::SimpleNode * cmpnode() const noexcept { return cmpnode_; @@ -117,10 +117,10 @@ class unrollinfo final [[nodiscard]] const rvsdg::SimpleOperation & cmpoperation() const noexcept { - return *static_cast(&cmpnode()->GetOperation()); + return cmpnode()->GetOperation(); } - inline rvsdg::Node * + inline rvsdg::SimpleNode * armnode() const noexcept { return armnode_; @@ -129,7 +129,7 @@ class unrollinfo final [[nodiscard]] const rvsdg::SimpleOperation & armoperation() const noexcept { - return *static_cast(&armnode()->GetOperation()); + return armnode()->GetOperation(); } inline rvsdg::output * @@ -206,10 +206,9 @@ class unrollinfo final inline bool is_known(jlm::rvsdg::output * output) const noexcept { - auto p = producer(output); + auto p = dynamic_cast(producer(output)); if (!p) return false; - auto op = dynamic_cast(&p->GetOperation()); return op && op->value().is_known(); } @@ -220,14 +219,14 @@ class unrollinfo final if (!is_known(output)) return nullptr; - auto p = producer(output); - return &static_cast(&p->GetOperation())->value(); + auto p = util::AssertedCast(producer(output)); + return &util::AssertedCast(&p->GetOperation())->value(); } rvsdg::output * end_; rvsdg::output * step_; - rvsdg::Node * cmpnode_; - rvsdg::Node * armnode_; + rvsdg::SimpleNode * cmpnode_; + rvsdg::SimpleNode * armnode_; rvsdg::output * idv_; }; diff --git a/jlm/rvsdg/NodeNormalization.hpp b/jlm/rvsdg/NodeNormalization.hpp index f4ce375e5..f3437e366 100644 --- a/jlm/rvsdg/NodeNormalization.hpp +++ b/jlm/rvsdg/NodeNormalization.hpp @@ -43,7 +43,7 @@ NormalizeSequence( template bool -ReduceNode(const NodeNormalization & nodeNormalization, Node & node) +ReduceNode(const NodeNormalization & nodeNormalization, SimpleNode & node) { auto operation = util::AssertedCast(&node.GetOperation()); auto operands = rvsdg::operands(&node); diff --git a/jlm/rvsdg/binary.cpp b/jlm/rvsdg/binary.cpp index 4edb79265..6f88f503d 100644 --- a/jlm/rvsdg/binary.cpp +++ b/jlm/rvsdg/binary.cpp @@ -81,10 +81,19 @@ FlattenAssociativeBinaryOperation( if (node == nullptr) return false; - auto flattenedBinaryOperation = - dynamic_cast(&node->GetOperation()); - return node->GetOperation() == operation - || (flattenedBinaryOperation && flattenedBinaryOperation->bin_operation() == operation); + auto simpleNode = dynamic_cast(node); + if (simpleNode) + { + auto flattenedBinaryOperation = + dynamic_cast(&simpleNode->GetOperation()); + return simpleNode->GetOperation() == operation + || (flattenedBinaryOperation + && flattenedBinaryOperation->bin_operation() == operation); + } + else + { + return false; + } }); if (operands == newOperands) @@ -220,12 +229,15 @@ FlattenedBinaryOperation::reduce( { for (auto & node : TopDownTraverser(region)) { - if (is(node)) + if (auto simpleNode = dynamic_cast(node)) { - const auto op = static_cast(&node->GetOperation()); - auto output = op->reduce(reduction, operands(node)); - node->output(0)->divert_users(output); - remove(node); + auto op = dynamic_cast(&simpleNode->GetOperation()); + if (op) + { + auto output = op->reduce(reduction, operands(node)); + node->output(0)->divert_users(output); + remove(node); + } } else if (auto structnode = dynamic_cast(node)) { diff --git a/jlm/rvsdg/bitstring/bitoperation-classes.cpp b/jlm/rvsdg/bitstring/bitoperation-classes.cpp index 297e00ec2..58fe82166 100644 --- a/jlm/rvsdg/bitstring/bitoperation-classes.cpp +++ b/jlm/rvsdg/bitstring/bitoperation-classes.cpp @@ -27,7 +27,7 @@ bitunary_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg { if (path == unop_reduction_constant) { - auto p = producer(arg); + auto p = static_cast(producer(arg)); auto & c = static_cast(p->GetOperation()); return create_bitconstant(p->region(), reduce_constant(c.value())); } @@ -57,8 +57,10 @@ bitbinary_op::reduce_operand_pair( { if (path == binop_reduction_constants) { - auto & c1 = static_cast(producer(arg1)->GetOperation()); - auto & c2 = static_cast(producer(arg2)->GetOperation()); + auto & c1 = static_cast( + static_cast(producer(arg1))->GetOperation()); + auto & c2 = static_cast( + static_cast(producer(arg2))->GetOperation()); return create_bitconstant(arg1->region(), reduce_constants(c1.value(), c2.value())); } @@ -73,12 +75,12 @@ bitcompare_op::can_reduce_operand_pair( const jlm::rvsdg::output * arg1, const jlm::rvsdg::output * arg2) const noexcept { - auto p = producer(arg1); + auto p = dynamic_cast(producer(arg1)); const bitconstant_op * c1_op = nullptr; if (p) c1_op = dynamic_cast(&p->GetOperation()); - p = producer(arg2); + p = dynamic_cast(producer(arg2)); const bitconstant_op * c2_op = nullptr; if (p) c2_op = dynamic_cast(&p->GetOperation()); diff --git a/jlm/rvsdg/bitstring/concat.cpp b/jlm/rvsdg/bitstring/concat.cpp index e6e0a7b96..5d58fd89e 100644 --- a/jlm/rvsdg/bitstring/concat.cpp +++ b/jlm/rvsdg/bitstring/concat.cpp @@ -58,8 +58,8 @@ bitconcat_op::can_reduce_operand_pair( const jlm::rvsdg::output * arg1, const jlm::rvsdg::output * arg2) const noexcept { - auto node1 = TryGetOwnerNode(*arg1); - auto node2 = TryGetOwnerNode(*arg2); + auto node1 = TryGetOwnerNode(*arg1); + auto node2 = TryGetOwnerNode(*arg2); if (!node1 || !node2) return binop_reduction_none; @@ -97,13 +97,13 @@ bitconcat_op::reduce_operand_pair( jlm::rvsdg::output * arg1, jlm::rvsdg::output * arg2) const { - auto node1 = static_cast(arg1)->node(); - auto node2 = static_cast(arg2)->node(); + auto & node1 = AssertGetOwnerNode(*arg1); + auto & node2 = AssertGetOwnerNode(*arg2); if (path == binop_reduction_constants) { - auto & arg1_constant = static_cast(node1->GetOperation()); - auto & arg2_constant = static_cast(node2->GetOperation()); + auto & arg1_constant = static_cast(node1.GetOperation()); + auto & arg2_constant = static_cast(node2.GetOperation()); bitvalue_repr bits(arg1_constant.value()); bits.Append(arg2_constant.value()); @@ -112,9 +112,9 @@ bitconcat_op::reduce_operand_pair( if (path == binop_reduction_merge) { - auto arg1_slice = static_cast(&node1->GetOperation()); - auto arg2_slice = static_cast(&node2->GetOperation()); - return jlm::rvsdg::bitslice(node1->input(0)->origin(), arg1_slice->low(), arg2_slice->high()); + auto arg1_slice = static_cast(&node1.GetOperation()); + auto arg2_slice = static_cast(&node2.GetOperation()); + return jlm::rvsdg::bitslice(node1.input(0)->origin(), arg1_slice->low(), arg2_slice->high()); /* FIXME: support sign bit */ } diff --git a/jlm/rvsdg/bitstring/slice.cpp b/jlm/rvsdg/bitstring/slice.cpp index 64e61943b..242cfd1fc 100644 --- a/jlm/rvsdg/bitstring/slice.cpp +++ b/jlm/rvsdg/bitstring/slice.cpp @@ -57,17 +57,17 @@ bitslice_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg return arg; } - auto node = static_cast(arg)->node(); + auto & node = AssertGetOwnerNode(*arg); if (path == unop_reduction_narrow) { - auto op = static_cast(node->GetOperation()); - return jlm::rvsdg::bitslice(node->input(0)->origin(), low() + op.low(), high() + op.low()); + auto op = static_cast(node.GetOperation()); + return jlm::rvsdg::bitslice(node.input(0)->origin(), low() + op.low(), high() + op.low()); } if (path == unop_reduction_constant) { - auto op = static_cast(node->GetOperation()); + auto op = static_cast(node.GetOperation()); std::string s(&op.value()[0] + low(), high() - low()); return create_bitconstant(arg->region(), s.c_str()); } @@ -76,9 +76,9 @@ bitslice_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg { size_t pos = 0, n; std::vector arguments; - for (n = 0; n < node->ninputs(); n++) + for (n = 0; n < node.ninputs(); n++) { - auto argument = node->input(n)->origin(); + auto argument = node.input(n)->origin(); size_t base = pos; size_t nbits = std::static_pointer_cast(argument->Type())->nbits(); pos = pos + nbits; diff --git a/jlm/rvsdg/control.cpp b/jlm/rvsdg/control.cpp index 1a61bc8a3..50e9a11b1 100644 --- a/jlm/rvsdg/control.cpp +++ b/jlm/rvsdg/control.cpp @@ -119,7 +119,8 @@ match_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg) c { if (path == unop_reduction_constant) { - auto op = static_cast(producer(arg)->GetOperation()); + auto op = static_cast( + static_cast(producer(arg))->GetOperation()); return jlm::rvsdg::control_constant( arg->region(), nalternatives(), diff --git a/jlm/rvsdg/node.hpp b/jlm/rvsdg/node.hpp index ee4cc5e06..35a780718 100644 --- a/jlm/rvsdg/node.hpp +++ b/jlm/rvsdg/node.hpp @@ -588,9 +588,6 @@ class Node explicit Node(Region * region); - [[nodiscard]] virtual const Operation & - GetOperation() const noexcept = 0; - inline bool has_users() const noexcept { @@ -1030,16 +1027,6 @@ divert_users(Node * node, const std::vector & outputs) node->output(n)->divert_users(outputs[n]); } -template -static inline bool -is(const Node * node) noexcept -{ - if (!node) - return false; - - return is(node->GetOperation()); -} - Node * producer(const jlm::rvsdg::output * output) noexcept; diff --git a/jlm/rvsdg/simple-node.cpp b/jlm/rvsdg/simple-node.cpp index 86c21a212..d1f88eaeb 100644 --- a/jlm/rvsdg/simple-node.cpp +++ b/jlm/rvsdg/simple-node.cpp @@ -121,9 +121,9 @@ NormalizeSimpleOperationCommonNodeElimination( { auto isCongruent = [&](const Node & node) { - auto & nodeOperation = node.GetOperation(); - return nodeOperation == operation && operands == rvsdg::operands(&node) - && &nodeOperation != &operation; + auto simpleNode = dynamic_cast(&node); + return simpleNode && simpleNode->GetOperation() == operation + && operands == rvsdg::operands(&node) && &simpleNode->GetOperation() != &operation; }; if (operands.empty()) diff --git a/jlm/rvsdg/simple-node.hpp b/jlm/rvsdg/simple-node.hpp index 5dd3159d8..0c64c94a9 100644 --- a/jlm/rvsdg/simple-node.hpp +++ b/jlm/rvsdg/simple-node.hpp @@ -36,7 +36,7 @@ class SimpleNode final : public Node output(size_t index) const noexcept; [[nodiscard]] const SimpleOperation & - GetOperation() const noexcept override; + GetOperation() const noexcept; Node * copy(rvsdg::Region * region, const std::vector & operands) const override; @@ -217,6 +217,17 @@ CreateOpNode(Region & region, OperatorArguments... operatorArguments) {}); } +template +static inline bool +is(const Node * node) noexcept +{ + if (!node) + return false; + + auto simple_node = dynamic_cast(node); + return simple_node && dynamic_cast(&simple_node->GetOperation()); +} + } #endif diff --git a/jlm/rvsdg/structural-node.hpp b/jlm/rvsdg/structural-node.hpp index 2fddbeec1..f13e835fb 100644 --- a/jlm/rvsdg/structural-node.hpp +++ b/jlm/rvsdg/structural-node.hpp @@ -31,6 +31,9 @@ class StructuralNode : public Node std::string DebugString() const override; + [[nodiscard]] virtual const StructuralOperation & + GetOperation() const noexcept = 0; + inline size_t nsubregions() const noexcept { diff --git a/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp b/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp index 5507debb2..fdf670b9c 100644 --- a/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp +++ b/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp @@ -63,7 +63,7 @@ TestEliminateSplitAndMergeNodes() assert(lambdaSubregion->nresults() == 1); assert(is(lambdaSubregion->result(0)->Type())); auto loadNode = - jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); + jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); assert(is(loadNode->GetOperation())); jlm::util::AssertedCast(loadNode->input(1)->origin()); @@ -136,14 +136,15 @@ TestInvariantMemoryState() assert(is(lambdaSubregion->result(0)->Type())); // Since there is more than one invariant memory state edge, the MemoryStateMerge node should // still exists - auto node = jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); + auto node = + jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); assert(is(node->GetOperation())); assert(node->ninputs() == 2); // Need to pass a load node to reach the MemoryStateSplit node - node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); + node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); assert(is(node->GetOperation())); // Check that the MemoryStateSplit node is still present - node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); + node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); assert(is(node->GetOperation())); return 0; diff --git a/tests/jlm/llvm/ir/operators/LoadTests.cpp b/tests/jlm/llvm/ir/operators/LoadTests.cpp index 8b5c753b0..d6b4aa731 100644 --- a/tests/jlm/llvm/ir/operators/LoadTests.cpp +++ b/tests/jlm/llvm/ir/operators/LoadTests.cpp @@ -66,12 +66,14 @@ TestCopy() auto loadResults = LoadNonVolatileOperation::Create(address1, { memoryState1 }, valueType, 4); // Act - auto node = jlm::rvsdg::TryGetOwnerNode(*loadResults[0]); + auto node = jlm::rvsdg::TryGetOwnerNode(*loadResults[0]); assert(is(node)); auto copiedNode = node->copy(&graph.GetRootRegion(), { address2, memoryState2 }); // Assert - assert(node->GetOperation() == copiedNode->GetOperation()); + assert( + node->GetOperation() + == jlm::util::AssertedCast(copiedNode)->GetOperation()); return 0; } @@ -638,7 +640,8 @@ NodeCopy() auto copiedNode = loadNode.copy(&graph.GetRootRegion(), { &address2, &iOState2, &memoryState2 }); // Assert - auto copiedOperation = dynamic_cast(&copiedNode->GetOperation()); + auto copiedOperation = dynamic_cast( + &jlm::util::AssertedCast(copiedNode)->GetOperation()); assert(copiedOperation != nullptr); assert(LoadOperation::AddressInput(*copiedNode).origin() == &address2); assert(LoadVolatileOperation::IOStateInput(*copiedNode).origin() == &iOState2); diff --git a/tests/jlm/llvm/ir/operators/StoreTests.cpp b/tests/jlm/llvm/ir/operators/StoreTests.cpp index a8c392c25..fac498ee0 100644 --- a/tests/jlm/llvm/ir/operators/StoreTests.cpp +++ b/tests/jlm/llvm/ir/operators/StoreTests.cpp @@ -194,11 +194,13 @@ TestCopy() auto storeResults = StoreNonVolatileOperation::Create(address1, value1, { memoryState1 }, 4); // Act - auto node = jlm::rvsdg::TryGetOwnerNode(*storeResults[0]); + auto node = jlm::rvsdg::TryGetOwnerNode(*storeResults[0]); auto copiedNode = node->copy(&graph.GetRootRegion(), { address2, value2, memoryState2 }); // Assert - assert(node->GetOperation() == copiedNode->GetOperation()); + assert( + node->GetOperation() + == jlm::util::AssertedCast(copiedNode)->GetOperation()); return 0; } @@ -240,9 +242,9 @@ TestStoreMuxNormalization() auto muxNode = jlm::rvsdg::TryGetOwnerNode(*ex.origin()); assert(is(muxNode)); assert(muxNode->ninputs() == 3); - auto n0 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(0)->origin()); - auto n1 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(1)->origin()); - auto n2 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(2)->origin()); + auto n0 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(0)->origin()); + auto n1 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(1)->origin()); + auto n2 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(2)->origin()); assert(jlm::rvsdg::is(n0->GetOperation())); assert(jlm::rvsdg::is(n1->GetOperation())); assert(jlm::rvsdg::is(n2->GetOperation())); diff --git a/tests/jlm/llvm/ir/operators/TestCall.cpp b/tests/jlm/llvm/ir/operators/TestCall.cpp index d7433fd36..11e0841c0 100644 --- a/tests/jlm/llvm/ir/operators/TestCall.cpp +++ b/tests/jlm/llvm/ir/operators/TestCall.cpp @@ -41,12 +41,14 @@ TestCopy() CallOperation::Create(function1, functionType, { value1, iOState1, memoryState1 }); // Act - auto node = jlm::rvsdg::TryGetOwnerNode(*callResults[0]); + auto node = jlm::rvsdg::TryGetOwnerNode(*callResults[0]); auto copiedNode = node->copy(&rvsdg.GetRootRegion(), { function2, value2, iOState2, memoryState2 }); // Assert - assert(node->GetOperation() == copiedNode->GetOperation()); + assert( + node->GetOperation() + == jlm::util::AssertedCast(copiedNode)->GetOperation()); } static void diff --git a/tests/jlm/llvm/ir/operators/test-sext.cpp b/tests/jlm/llvm/ir/operators/test-sext.cpp index 994ed3d3a..772f72ec9 100644 --- a/tests/jlm/llvm/ir/operators/test-sext.cpp +++ b/tests/jlm/llvm/ir/operators/test-sext.cpp @@ -36,7 +36,7 @@ test_bitunary_reduction() // Act ReduceNode( NormalizeUnaryOperation, - *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); + *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); @@ -68,7 +68,7 @@ test_bitbinary_reduction() // Act ReduceNode( NormalizeUnaryOperation, - *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); + *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); @@ -97,7 +97,9 @@ test_inverse_reduction() view(graph, stdout); // Act - ReduceNode(NormalizeUnaryOperation, *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); + ReduceNode( + NormalizeUnaryOperation, + *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); diff --git a/tests/jlm/llvm/opt/IfConversionTests.cpp b/tests/jlm/llvm/opt/IfConversionTests.cpp index cf5a87240..c28b77b08 100644 --- a/tests/jlm/llvm/opt/IfConversionTests.cpp +++ b/tests/jlm/llvm/opt/IfConversionTests.cpp @@ -127,10 +127,11 @@ EmptyGammaWithTwoSubregionsAndMatch() assert(selectNode->input(2)->origin() == falseValue); const auto eqNode = - jlm::rvsdg::TryGetOwnerNode(*selectNode->input(0)->origin()); + jlm::rvsdg::TryGetOwnerNode(*selectNode->input(0)->origin()); assert(eqNode && is(eqNode)); - auto constantNode = jlm::rvsdg::TryGetOwnerNode(*eqNode->input(0)->origin()); + auto constantNode = + jlm::rvsdg::TryGetOwnerNode(*eqNode->input(0)->origin()); if (constantNode) { assert(eqNode->input(1)->origin() == conditionValue); @@ -142,7 +143,7 @@ EmptyGammaWithTwoSubregionsAndMatch() else { assert(eqNode->input(0)->origin() == conditionValue); - constantNode = jlm::rvsdg::TryGetOwnerNode(*eqNode->input(1)->origin()); + constantNode = jlm::rvsdg::TryGetOwnerNode(*eqNode->input(1)->origin()); auto constantOperation = dynamic_cast(&constantNode->GetOperation()); assert(constantOperation); diff --git a/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp b/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp index eba39245c..5cd440237 100644 --- a/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp +++ b/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp @@ -97,7 +97,9 @@ TestIntegerBinaryOperation() bool foundBinaryOp = false; for (auto & node : region->Nodes()) { - auto convertedBinaryOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedBinaryOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedBinaryOp) { assert(convertedBinaryOp->nresults() == 1); @@ -225,7 +227,9 @@ TestIntegerComparisonOperation(const IntegerComparisonOpTest & tes bool foundCompOp = false; for (auto & node : region->Nodes()) { - auto convertedCompOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedCompOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedCompOp) { assert(convertedCompOp->nresults() == 1); diff --git a/tests/jlm/mlir/TestJlmToMlirToJlm.cpp b/tests/jlm/mlir/TestJlmToMlirToJlm.cpp index 8b13429b3..b79840e38 100644 --- a/tests/jlm/mlir/TestJlmToMlirToJlm.cpp +++ b/tests/jlm/mlir/TestJlmToMlirToJlm.cpp @@ -61,8 +61,9 @@ TestUndef() assert(region->nnodes() == 1); // Get the undef op - auto convertedUndef = - dynamic_cast(®ion->Nodes().begin()->GetOperation()); + auto convertedUndef = dynamic_cast( + &jlm::util::AssertedCast(&*region->Nodes().begin()) + ->GetOperation()); assert(convertedUndef != nullptr); @@ -139,7 +140,9 @@ TestAlloca() bool foundAlloca = false; for (auto & node : region->Nodes()) { - if (auto allocaOp = dynamic_cast(&node.GetOperation())) + auto simpleNode = dynamic_cast(&node); + if (auto allocaOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr)) { assert(allocaOp->alignment() == 4); @@ -241,14 +244,14 @@ TestLoad() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is( - convertedLambda->subregion()->Nodes().begin()->GetOperation())); + assert(is(jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation())); auto convertedLoad = convertedLambda->subregion()->Nodes().begin().ptr(); - auto loadOperation = - dynamic_cast(&convertedLoad->GetOperation()); + auto loadOperation = dynamic_cast( + &jlm::util::AssertedCast(convertedLoad)->GetOperation()); assert(loadOperation->GetAlignment() == 4); assert(loadOperation->NumMemoryStates() == 1); @@ -338,14 +341,14 @@ TestStore() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is( - convertedLambda->subregion()->Nodes().begin()->GetOperation())); + assert(is(jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation())); auto convertedStore = convertedLambda->subregion()->Nodes().begin().ptr(); - auto convertedStoreOperation = - dynamic_cast(&convertedStore->GetOperation()); + auto convertedStoreOperation = dynamic_cast( + &jlm::util::AssertedCast(convertedStore)->GetOperation()); assert(convertedStoreOperation->GetAlignment() == 4); assert(convertedStoreOperation->NumMemoryStates() == 1); @@ -425,12 +428,13 @@ TestSext() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); - auto convertedSext = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + auto convertedSext = + dynamic_cast(&jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedSext); assert(convertedSext->ndstbits() == 64); assert(convertedSext->nsrcbits() == 32); @@ -501,9 +505,11 @@ TestSitofp() auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); auto convertedSitofp = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + &jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedSitofp); assert(jlm::rvsdg::is(*convertedSitofp->argument(0).get())); assert(jlm::rvsdg::is(*convertedSitofp->result(0).get())); @@ -560,9 +566,11 @@ TestConstantFP() auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); - auto convertedConst = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + auto convertedConst = + dynamic_cast(&jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedConst); assert(jlm::rvsdg::is(*convertedConst->result(0).get())); assert(convertedConst->constant().isExactlyValue(2.0)); @@ -643,7 +651,8 @@ TestFpBinary() assert(convertedLambda->subregion()->nnodes() == 1); auto node = convertedLambda->subregion()->Nodes().begin().ptr(); - auto convertedFpbin = jlm::util::AssertedCast(&node->GetOperation()); + auto convertedFpbin = jlm::util::AssertedCast( + &jlm::util::AssertedCast(&*node)->GetOperation()); assert(convertedFpbin->fpop() == binOp); assert(convertedFpbin->nresults() == 1); assert(convertedFpbin->narguments() == 2); @@ -732,8 +741,9 @@ TestGetElementPtr() assert(convertedLambda->subregion()->nnodes() == 1); auto op = convertedLambda->subregion()->Nodes().begin(); - assert(is(op->GetOperation())); - auto convertedGep = dynamic_cast(&op->GetOperation()); + auto convertedGep = dynamic_cast( + &jlm::util::AssertedCast(&*op)->GetOperation()); + assert(convertedGep); assert(is(convertedGep->GetPointeeType())); assert(is(convertedGep->result(0))); @@ -847,7 +857,8 @@ TestDelta() assert(convertedDelta->Section() == "section"); auto op = convertedDelta->subregion()->Nodes().begin(); - assert(is(op->GetOperation())); + assert(is( + jlm::util::AssertedCast(&*op)->GetOperation())); } } } @@ -912,7 +923,9 @@ TestConstantDataArray() bool foundConstantDataArray = false; for (auto & node : region->Nodes()) { - if (auto constantDataArray = dynamic_cast(&node.GetOperation())) + auto simpleNode = dynamic_cast(&node); + if (auto constantDataArray = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr)) { foundConstantDataArray = true; assert(constantDataArray->nresults() == 1); @@ -977,7 +990,8 @@ TestConstantAggregateZero() assert(region->nnodes() == 1); auto const convertedConstantAggregateZero = jlm::util::AssertedCast( - ®ion->Nodes().begin().ptr()->GetOperation()); + &jlm::util::AssertedCast(&*region->Nodes().begin().ptr()) + ->GetOperation()); assert(convertedConstantAggregateZero->nresults() == 1); assert(convertedConstantAggregateZero->narguments() == 0); auto resultType = convertedConstantAggregateZero->result(0); @@ -1044,7 +1058,9 @@ TestVarArgList() bool foundVarArgOp = false; for (auto & node : region->Nodes()) { - auto convertedVarArgOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedVarArgOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedVarArgOp) { assert(convertedVarArgOp->nresults() == 1); @@ -1119,7 +1135,9 @@ TestFNeg() bool foundFNegOp = false; for (auto & node : region->Nodes()) { - auto convertedFNegOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedFNegOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedFNegOp) { assert(convertedFNegOp->nresults() == 1); @@ -1197,7 +1215,9 @@ TestFPExt() bool foundFPExtOp = false; for (auto & node : region->Nodes()) { - auto convertedFPExtOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedFPExtOp = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedFPExtOp) { assert(convertedFPExtOp->nresults() == 1); @@ -1274,7 +1294,9 @@ TestTrunc() bool foundTruncOp = false; for (auto & node : region->Nodes()) { - auto convertedTruncOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedTruncOp = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedTruncOp) { assert(convertedTruncOp->nresults() == 1); @@ -1365,12 +1387,13 @@ TestFree() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); - auto convertedFree = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + auto convertedFree = + dynamic_cast(&jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedFree); assert(convertedFree->narguments() == 3); assert(convertedFree->nresults() == 2); @@ -1626,9 +1649,11 @@ TestIOBarrier() // Find the IOBarrier in the lambda subregion bool foundIOBarrier = false; - for (auto & lambdaNode : lambdaOperation->subregion()->Nodes()) + for (auto & node : lambdaOperation->subregion()->Nodes()) { - auto ioBarrierOp = dynamic_cast(&lambdaNode.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto ioBarrierOp = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if (ioBarrierOp) { foundIOBarrier = true; @@ -1712,7 +1737,9 @@ TestMalloc() bool foundMallocOp = false; for (auto & node : region->Nodes()) { - auto convertedMallocOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedMallocOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedMallocOp) { assert(convertedMallocOp->nresults() == 2); diff --git a/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp b/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp index 7d6cd4631..a18392d63 100644 --- a/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp +++ b/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp @@ -267,17 +267,15 @@ TestDivOperation() // Get the lambda block auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); + assert(is(convertedLambda->GetOperation())); // 2 Constants + 1 DivUIOp assert(convertedLambda->subregion()->nnodes() == 3); // Traverse the rvsgd graph upwards to check connections - jlm::rvsdg::node_output * lambdaResultOriginNodeOuput; - assert( - lambdaResultOriginNodeOuput = dynamic_cast( - convertedLambda->subregion()->result(0)->origin())); - Node * lambdaResultOriginNode = lambdaResultOriginNodeOuput->node(); + auto lambdaResultOriginNode = jlm::rvsdg::TryGetOwnerNode( + *convertedLambda->subregion()->result(0)->origin()); + assert(lambdaResultOriginNode); assert(is(lambdaResultOriginNode->GetOperation())); assert(lambdaResultOriginNode->ninputs() == 2); @@ -441,7 +439,7 @@ TestCompZeroExt() // Get the lambda block auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); + assert(is(convertedLambda->GetOperation())); // 2 Constants + AddOp + CompOp + ZeroExtOp assert(convertedLambda->subregion()->nnodes() == 5); @@ -646,7 +644,7 @@ TestMatchOp() // Get the lambda block auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); + assert(is(convertedLambda->GetOperation())); auto lambdaRegion = convertedLambda->subregion(); diff --git a/tests/jlm/rvsdg/SimpleOperationTests.cpp b/tests/jlm/rvsdg/SimpleOperationTests.cpp index 7657e1e4b..f006fa851 100644 --- a/tests/jlm/rvsdg/SimpleOperationTests.cpp +++ b/tests/jlm/rvsdg/SimpleOperationTests.cpp @@ -51,10 +51,18 @@ NormalizeSimpleOperationCne_NodesWithoutOperands() }; // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryValueNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryValueNode2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryStateNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryStateNode2.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryValueNode1.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryValueNode2.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryStateNode1.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryStateNode2.origin())); graph.PruneNodes(); view(graph, stdout); @@ -106,10 +114,10 @@ NormalizeSimpleOperationCne_NodesWithOperands() }; // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode2.origin())); graph.PruneNodes(); view(graph, stdout); @@ -163,10 +171,18 @@ NormalizeSimpleOperationCne_Failure() }; // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryValueNode.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryStateNode.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exUnaryValueNode.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exUnaryStateNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryValueNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryStateNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exUnaryValueNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exUnaryStateNode.origin())); graph.PruneNodes(); view(graph, stdout); diff --git a/tests/jlm/rvsdg/bitstring/bitstring.cpp b/tests/jlm/rvsdg/bitstring/bitstring.cpp index 06f897ae0..cfdea0647 100644 --- a/tests/jlm/rvsdg/bitstring/bitstring.cpp +++ b/tests/jlm/rvsdg/bitstring/bitstring.cpp @@ -40,8 +40,8 @@ types_bitstring_arithmetic_test_bitand() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitand_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, +1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitand_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, +1)); return 0; } @@ -86,11 +86,11 @@ types_bitstring_arithmetic_test_bitashr() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitashr_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 4)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 0)); - assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == int_constant_op(32, -4)); - assert(TryGetOwnerNode(*ex4.origin())->GetOperation() == int_constant_op(32, -1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitashr_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 4)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 0)); + assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == int_constant_op(32, -4)); + assert(TryGetOwnerNode(*ex4.origin())->GetOperation() == int_constant_op(32, -1)); return 0; } @@ -118,7 +118,7 @@ types_bitstring_arithmetic_test_bitdifference() view(&graph.GetRootRegion(), stdout); // Act - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsub_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsub_op(32)); return 0; } @@ -152,9 +152,9 @@ types_bitstring_arithmetic_test_bitnegate() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitneg_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -3)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitneg_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -3)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); return 0; } @@ -188,9 +188,9 @@ types_bitstring_arithmetic_test_bitnot() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitnot_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -4)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitnot_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -4)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); return 0; } @@ -224,8 +224,8 @@ types_bitstring_arithmetic_test_bitor() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitor_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 7)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitor_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 7)); return 0; } @@ -259,8 +259,8 @@ types_bitstring_arithmetic_test_bitproduct() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitmul_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 15)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitmul_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 15)); return 0; } @@ -288,7 +288,7 @@ types_bitstring_arithmetic_test_bitshiproduct() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmulh_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmulh_op(32)); return 0; } @@ -326,9 +326,9 @@ types_bitstring_arithmetic_test_bitshl() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshl_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 64)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshl_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 64)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); return 0; } @@ -366,9 +366,9 @@ types_bitstring_arithmetic_test_bitshr() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshr_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 4)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshr_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 4)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); return 0; } @@ -402,8 +402,8 @@ types_bitstring_arithmetic_test_bitsmod() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmod_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmod_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -1)); return 0; } @@ -437,8 +437,8 @@ types_bitstring_arithmetic_test_bitsquotient() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsdiv_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -2)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsdiv_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -2)); return 0; } @@ -472,8 +472,8 @@ types_bitstring_arithmetic_test_bitsum() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitadd_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 8)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitadd_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 8)); return 0; } @@ -501,7 +501,7 @@ types_bitstring_arithmetic_test_bituhiproduct() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumulh_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumulh_op(32)); return 0; } @@ -535,8 +535,8 @@ types_bitstring_arithmetic_test_bitumod() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumod_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumod_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 1)); return 0; } @@ -570,8 +570,8 @@ types_bitstring_arithmetic_test_bituquotient() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitudiv_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 2)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitudiv_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 2)); return 0; } @@ -605,8 +605,8 @@ types_bitstring_arithmetic_test_bitxor() view(&graph.GetRootRegion(), stdout); // Arrange - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitxor_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 6)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitxor_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 6)); return 0; } @@ -614,7 +614,7 @@ types_bitstring_arithmetic_test_bitxor() static inline void expect_static_true(jlm::rvsdg::output * port) { - auto node = jlm::rvsdg::TryGetOwnerNode(*port); + auto node = jlm::rvsdg::TryGetOwnerNode(*port); auto op = dynamic_cast(&node->GetOperation()); assert(op && op->value().nbits() == 1 && op->value().str() == "1"); } @@ -622,7 +622,7 @@ expect_static_true(jlm::rvsdg::output * port) static inline void expect_static_false(jlm::rvsdg::output * port) { - auto node = jlm::rvsdg::TryGetOwnerNode(*port); + auto node = jlm::rvsdg::TryGetOwnerNode(*port); auto op = dynamic_cast(&node->GetOperation()); assert(op && op->value().nbits() == 1 && op->value().str() == "0"); } @@ -662,10 +662,10 @@ types_bitstring_comparison_test_bitequal() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == biteq_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == biteq_op(32)); expect_static_true(ex1.origin()); expect_static_false(ex2.origin()); - assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == biteq_op(32)); + assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == biteq_op(32)); return 0; } @@ -705,10 +705,10 @@ types_bitstring_comparison_test_bitnotequal() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitne_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitne_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); - assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == bitne_op(32)); + assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == bitne_op(32)); return 0; } @@ -753,7 +753,7 @@ types_bitstring_comparison_test_bitsgreater() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsgt_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsgt_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -804,7 +804,7 @@ types_bitstring_comparison_test_bitsgreatereq() view(&graph.GetRootRegion(), stdout); // Arrange - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsge_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsge_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_true(ex3.origin()); @@ -854,7 +854,7 @@ types_bitstring_comparison_test_bitsless() view(&graph.GetRootRegion(), stdout); // Arrange - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitslt_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitslt_op(32)); expect_static_true(ex1.origin()); expect_static_false(ex2.origin()); expect_static_false(ex3.origin()); @@ -906,7 +906,7 @@ types_bitstring_comparison_test_bitslesseq() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsle_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsle_op(32)); expect_static_true(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -955,7 +955,7 @@ types_bitstring_comparison_test_bitugreater() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitugt_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitugt_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -1006,7 +1006,7 @@ types_bitstring_comparison_test_bitugreatereq() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bituge_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bituge_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_true(ex3.origin()); @@ -1056,7 +1056,7 @@ types_bitstring_comparison_test_bituless() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitult_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitult_op(32)); expect_static_true(ex1.origin()); expect_static_false(ex2.origin()); expect_static_false(ex3.origin()); @@ -1108,7 +1108,7 @@ types_bitstring_comparison_test_bitulesseq() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitule_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitule_op(32)); expect_static_true(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -1179,19 +1179,19 @@ types_bitstring_test_constant() assert(b1.GetOperation() == uint_constant_op(8, 204)); assert(b1.GetOperation() == int_constant_op(8, -52)); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex3.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex4.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex3.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex4.origin())); assert(ex1.origin() == ex2.origin()); assert(ex1.origin() == ex3.origin()); - const auto node1 = TryGetOwnerNode(*ex1.origin()); + const auto node1 = TryGetOwnerNode(*ex1.origin()); assert(node1->GetOperation() == uint_constant_op(8, 204)); assert(node1->GetOperation() == int_constant_op(8, -52)); - const auto node4 = TryGetOwnerNode(*ex4.origin()); + const auto node4 = TryGetOwnerNode(*ex4.origin()); assert(node4->GetOperation() == uint_constant_op(9, 204)); assert(node4->GetOperation() == int_constant_op(9, 204)); @@ -1230,14 +1230,14 @@ types_bitstring_test_normalize() // Act ReduceNode(FlattenAssociativeBinaryOperation, sum1); - auto & flattenedBinaryNode = *TryGetOwnerNode(*ex.origin()); + auto & flattenedBinaryNode = *TryGetOwnerNode(*ex.origin()); ReduceNode(NormalizeFlattenedBinaryOperation, flattenedBinaryNode); graph.PruneNodes(); view(&graph.GetRootRegion(), stdout); // Assert - auto node = TryGetOwnerNode(*ex.origin()); + auto node = TryGetOwnerNode(*ex.origin()); assert(node->GetOperation() == bitadd_op(32)); assert(node->ninputs() == 2); auto op1 = node->input(0)->origin(); @@ -1249,7 +1249,7 @@ types_bitstring_test_normalize() op2 = tmp; } /* FIXME: the graph traversers are currently broken, that is why it won't normalize */ - assert(TryGetOwnerNode(*op1)->GetOperation() == int_constant_op(32, 3 + 4)); + assert(TryGetOwnerNode(*op1)->GetOperation() == int_constant_op(32, 3 + 4)); assert(op2 == imp); view(&graph.GetRootRegion(), stdout); @@ -1260,7 +1260,7 @@ types_bitstring_test_normalize() static void assert_constant(jlm::rvsdg::output * bitstr, size_t nbits, const char bits[]) { - auto node = jlm::rvsdg::TryGetOwnerNode(*bitstr); + auto node = jlm::rvsdg::TryGetOwnerNode(*bitstr); auto op = dynamic_cast(node->GetOperation()); assert(op.value() == jlm::rvsdg::bitvalue_repr(std::string(bits, nbits).c_str())); } @@ -1402,7 +1402,7 @@ ConcatOfSliceReduction() view(&graph.GetRootRegion(), stdout); // Assert - const auto sliceNode = TryGetOwnerNode(*ex.origin()); + const auto sliceNode = TryGetOwnerNode(*ex.origin()); assert(sliceNode->GetOperation() == bitslice_op(bit16Type, 0, 16)); assert(sliceNode->input(0)->origin() == x); @@ -1435,7 +1435,7 @@ SliceOfConstant() view(graph, stdout); // Assert - const auto node = TryGetOwnerNode(*ex.origin()); + const auto node = TryGetOwnerNode(*ex.origin()); auto & operation = dynamic_cast(node->GetOperation()); assert(operation.value() == bitvalue_repr("1101")); @@ -1468,7 +1468,7 @@ SliceOfSlice() view(graph, stdout); // Assert - const auto node = TryGetOwnerNode(*ex.origin()); + const auto node = TryGetOwnerNode(*ex.origin()); const auto operation = dynamic_cast(&node->GetOperation()); assert(operation->low() == 3 && operation->high() == 5); @@ -1527,11 +1527,11 @@ SliceOfConcat() // Act ReduceNode(NormalizeUnaryOperation, sliceNode); - auto concatNode = TryGetOwnerNode(*ex.origin()); + auto concatNode = TryGetOwnerNode(*ex.origin()); ReduceNode( NormalizeUnaryOperation, - *TryGetOwnerNode(*concatNode->input(0)->origin())); - concatNode = TryGetOwnerNode(*ex.origin()); + *TryGetOwnerNode(*concatNode->input(0)->origin())); + concatNode = TryGetOwnerNode(*ex.origin()); ReduceNode(NormalizeBinaryOperation, *concatNode); graph.PruneNodes(); @@ -1565,13 +1565,13 @@ ConcatFlattening() view(graph, stdout); // Act - const auto concatNode = TryGetOwnerNode(*ex.origin()); + const auto concatNode = TryGetOwnerNode(*ex.origin()); ReduceNode(FlattenBitConcatOperation, *concatNode); view(graph, stdout); // Assert - auto node = TryGetOwnerNode(*ex.origin()); + auto node = TryGetOwnerNode(*ex.origin()); assert(dynamic_cast(&node->GetOperation())); assert(node->ninputs() == 3); assert(node->input(0)->origin() == x); @@ -1637,7 +1637,7 @@ ConcatOfSlices() // Act ReduceNode(NormalizeBinaryOperation, concatNode); - ReduceNode(NormalizeUnaryOperation, *TryGetOwnerNode(*ex.origin())); + ReduceNode(NormalizeUnaryOperation, *TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); @@ -1666,10 +1666,10 @@ ConcatOfConstants() view(graph, stdout); // Act - ReduceNode(NormalizeBinaryOperation, *TryGetOwnerNode(*ex.origin())); + ReduceNode(NormalizeBinaryOperation, *TryGetOwnerNode(*ex.origin())); // Assert - auto node = TryGetOwnerNode(*ex.origin()); + auto node = TryGetOwnerNode(*ex.origin()); auto operation = dynamic_cast(node->GetOperation()); assert(operation.value() == bitvalue_repr("0011011111001000")); @@ -1709,8 +1709,8 @@ ConcatCne() view(graph, stdout); // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); graph.PruneNodes(); view(graph, stdout); @@ -1752,8 +1752,8 @@ SliceCne() view(graph, stdout); // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); graph.PruneNodes(); view(graph, stdout); diff --git a/tests/jlm/rvsdg/test-cse.cpp b/tests/jlm/rvsdg/test-cse.cpp index eb2d54cb7..87b8594ca 100644 --- a/tests/jlm/rvsdg/test-cse.cpp +++ b/tests/jlm/rvsdg/test-cse.cpp @@ -41,29 +41,29 @@ test_main() auto & e4 = jlm::tests::GraphExport::Create(*o4, "o4"); // Act & Assert - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e3.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e4.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e3.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e4.origin())); assert(e1.origin() == e3.origin()); assert(e2.origin() == e4.origin()); auto o5 = jlm::tests::create_testop(&graph.GetRootRegion(), {}, { valueType })[0]; auto & e5 = jlm::tests::GraphExport::Create(*o5, "o5"); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e5.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e5.origin())); assert(e5.origin() == e1.origin()); auto o6 = jlm::tests::create_testop(&graph.GetRootRegion(), { i }, { valueType })[0]; auto & e6 = jlm::tests::GraphExport::Create(*o6, "o6"); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e6.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e6.origin())); assert(e6.origin() == e2.origin()); auto o7 = jlm::tests::create_testop(&graph.GetRootRegion(), {}, { valueType })[0]; auto & e7 = jlm::tests::GraphExport::Create(*o7, "o7"); assert(e7.origin() != e1.origin()); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e7.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e7.origin())); assert(e7.origin() == e1.origin()); return 0; diff --git a/tests/jlm/rvsdg/test-gamma.cpp b/tests/jlm/rvsdg/test-gamma.cpp index dde2897b6..12833a262 100644 --- a/tests/jlm/rvsdg/test-gamma.cpp +++ b/tests/jlm/rvsdg/test-gamma.cpp @@ -152,7 +152,7 @@ test_control_constant_reduction() view(&graph.GetRootRegion(), stdout); // Assert - auto match = TryGetOwnerNode(*ex1.origin()); + auto match = TryGetOwnerNode(*ex1.origin()); assert(match && is(match->GetOperation())); auto & match_op = to_match_op(match->GetOperation()); assert(match_op.default_alternative() == 0);