diff --git a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp index 5fffd499e..af54ef179 100644 --- a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp +++ b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp @@ -1153,10 +1153,10 @@ RhlsToFirrtlConverter::MlirGenHlsLocalMem(const jlm::rvsdg::SimpleNode * node) { auto lmem_op = dynamic_cast(&(node->GetOperation())); JLM_ASSERT(lmem_op); - auto res_node = rvsdg::TryGetOwnerNode(**node->output(0)->begin()); + auto res_node = rvsdg::TryGetOwnerNode(**node->output(0)->begin()); auto res_op = dynamic_cast(&res_node->GetOperation()); JLM_ASSERT(res_op); - auto req_node = rvsdg::TryGetOwnerNode(**node->output(1)->begin()); + auto req_node = rvsdg::TryGetOwnerNode(**node->output(1)->begin()); auto req_op = dynamic_cast(&req_node->GetOperation()); JLM_ASSERT(req_op); // Create the module and its input/output ports - we use a non-standard way here @@ -2788,8 +2788,8 @@ RhlsToFirrtlConverter::createInstances( { if (auto sn = dynamic_cast(node)) { - if (dynamic_cast(&(node->GetOperation())) - || dynamic_cast(&(node->GetOperation()))) + if (dynamic_cast(&(sn->GetOperation())) + || dynamic_cast(&(sn->GetOperation()))) { // these are virtual - connections go to local_mem instead continue; @@ -3993,7 +3993,7 @@ RhlsToFirrtlConverter::GetFirrtlType(const jlm::rvsdg::Type * type) } std::string -RhlsToFirrtlConverter::GetModuleName(const rvsdg::Node * node) +RhlsToFirrtlConverter::GetModuleName(const rvsdg::SimpleNode * node) { std::string append = ""; @@ -4085,7 +4085,7 @@ RhlsToFirrtlConverter::IsIdentityMapping(const jlm::rvsdg::match_op & op) void RhlsToFirrtlConverter::WriteModuleToFile( const circt::firrtl::FModuleOp fModuleOp, - const rvsdg::Node * node) + const rvsdg::SimpleNode * node) { if (!fModuleOp) return; diff --git a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp index ac61d9b41..662e61422 100644 --- a/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp +++ b/jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp @@ -68,7 +68,7 @@ class RhlsToFirrtlConverter : public BaseHLS MlirGen(const rvsdg::LambdaNode * lamdaNode); void - WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::Node * node); + WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::SimpleNode * node); void WriteCircuitToFile(const circt::firrtl::CircuitOp circuit, std::string name); @@ -283,7 +283,7 @@ class RhlsToFirrtlConverter : public BaseHLS circt::firrtl::FIRRTLBaseType GetFirrtlType(const jlm::rvsdg::Type * type); std::string - GetModuleName(const rvsdg::Node * node); + GetModuleName(const rvsdg::SimpleNode * node); bool IsIdentityMapping(const jlm::rvsdg::match_op & op); diff --git a/jlm/hls/backend/rhls2firrtl/dot-hls.cpp b/jlm/hls/backend/rhls2firrtl/dot-hls.cpp index a770455bc..62b9ad727 100644 --- a/jlm/hls/backend/rhls2firrtl/dot-hls.cpp +++ b/jlm/hls/backend/rhls2firrtl/dot-hls.cpp @@ -248,8 +248,10 @@ DotHLS::loop_to_dot(hls::loop_node * ln) dot << "{rank=same "; for (auto node : rvsdg::TopDownTraverser(sr)) { - auto mx = dynamic_cast(&node->GetOperation()); - auto lc = dynamic_cast(&node->GetOperation()); + auto simpleNode = dynamic_cast(node); + auto mx = dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); + auto lc = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if ((mx && !mx->discarding && mx->loop) || lc) { dot << get_node_name(node) << " "; @@ -260,7 +262,9 @@ DotHLS::loop_to_dot(hls::loop_node * ln) dot << "{rank=same "; for (auto node : rvsdg::TopDownTraverser(sr)) { - auto br = dynamic_cast(&node->GetOperation()); + auto simpleNode = dynamic_cast(node); + auto br = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (br && br->loop) { dot << get_node_name(node) << " "; @@ -272,9 +276,9 @@ DotHLS::loop_to_dot(hls::loop_node * ln) // do edges outside in order not to pull other nodes into the cluster for (auto node : rvsdg::TopDownTraverser(sr)) { - if (dynamic_cast(node)) + if (auto simpleNode = dynamic_cast(node)) { - auto mx = dynamic_cast(&node->GetOperation()); + auto mx = dynamic_cast(&simpleNode->GetOperation()); auto node_name = get_node_name(node); for (size_t i = 0; i < node->ninputs(); ++i) { diff --git a/jlm/hls/backend/rvsdg2rhls/add-prints.cpp b/jlm/hls/backend/rvsdg2rhls/add-prints.cpp index ebf7a4342..2b82901d4 100644 --- a/jlm/hls/backend/rvsdg2rhls/add-prints.cpp +++ b/jlm/hls/backend/rvsdg2rhls/add-prints.cpp @@ -115,20 +115,23 @@ convert_prints( convert_prints(structnode->subregion(n), printf, functionType); } } - else if (auto po = dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto printf_local = route_to_region_rvsdg(printf, region); // TODO: prevent repetition? - auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id()); - jlm::rvsdg::output * val = node->input(0)->origin(); - if (*val->Type() != *jlm::rvsdg::bittype::Create(64)) + if (auto po = dynamic_cast(&(simpleNode->GetOperation()))) { - auto bt = std::dynamic_pointer_cast(val->Type()); - JLM_ASSERT(bt); - val = &llvm::ZExtOperation::Create(*val, rvsdg::bittype::Create(64)); + auto printf_local = route_to_region_rvsdg(printf, region); // TODO: prevent repetition? + auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id()); + jlm::rvsdg::output * val = node->input(0)->origin(); + if (*val->Type() != *jlm::rvsdg::bittype::Create(64)) + { + auto bt = std::dynamic_pointer_cast(val->Type()); + JLM_ASSERT(bt); + val = &llvm::ZExtOperation::Create(*val, rvsdg::bittype::Create(64)); + } + llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val }); + node->output(0)->divert_users(node->input(0)->origin()); + jlm::rvsdg::remove(node); } - llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val }); - node->output(0)->divert_users(node->input(0)->origin()); - jlm::rvsdg::remove(node); } } } diff --git a/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp b/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp index e26543080..1579374b7 100644 --- a/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp +++ b/jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp @@ -127,99 +127,101 @@ alloca_conv(rvsdg::Region * region) alloca_conv(structnode->subregion(n)); } } - else if (auto po = dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - // ensure that the size is one - JLM_ASSERT(node->ninputs() == 1); - auto constant_output = dynamic_cast(node->input(0)->origin()); - JLM_ASSERT(constant_output); - auto constant_operation = dynamic_cast( - &constant_output->node()->GetOperation()); - JLM_ASSERT(constant_operation); - JLM_ASSERT(constant_operation->Representation().to_uint() == 1); - // ensure that the alloca is an array type - auto at = std::dynamic_pointer_cast(po->ValueType()); - JLM_ASSERT(at); - // detect loads and stores attached to alloca - TraceAllocaUses ta(node->output(0)); - // create memory + response - auto mem_outs = local_mem_op::create(at, node->region()); - auto resp_outs = local_mem_resp_op::create(*mem_outs[0], ta.load_nodes.size()); - std::cout << "alloca converted " << at->debug_string() << std::endl; - // replace gep outputs (convert pointer to index calculation) - // replace loads and stores - std::vector load_addrs; - for (auto l : ta.load_nodes) + if (auto po = dynamic_cast(&(simpleNode->GetOperation()))) { - auto index = gep_to_index(l->input(0)->origin()); - auto response = route_response_rhls(l->region(), resp_outs.front()); - resp_outs.erase(resp_outs.begin()); - std::vector states; - for (size_t i = 1; i < l->ninputs(); ++i) + // ensure that the size is one + JLM_ASSERT(node->ninputs() == 1); + auto & constant_node = + rvsdg::AssertGetOwnerNode(*node->input(0)->origin()); + auto constant_operation = + util::AssertedCast(&constant_node.GetOperation()); + JLM_ASSERT(constant_operation->Representation().to_uint() == 1); + // ensure that the alloca is an array type + auto at = std::dynamic_pointer_cast(po->ValueType()); + JLM_ASSERT(at); + // detect loads and stores attached to alloca + TraceAllocaUses ta(node->output(0)); + // create memory + response + auto mem_outs = local_mem_op::create(at, node->region()); + auto resp_outs = local_mem_resp_op::create(*mem_outs[0], ta.load_nodes.size()); + std::cout << "alloca converted " << at->debug_string() << std::endl; + // replace gep outputs (convert pointer to index calculation) + // replace loads and stores + std::vector load_addrs; + for (auto l : ta.load_nodes) { - states.push_back(l->input(i)->origin()); + auto index = gep_to_index(l->input(0)->origin()); + auto response = route_response_rhls(l->region(), resp_outs.front()); + resp_outs.erase(resp_outs.begin()); + std::vector states; + for (size_t i = 1; i < l->ninputs(); ++i) + { + states.push_back(l->input(i)->origin()); + } + auto load_outs = local_load_op::create(*index, states, *response); + auto nn = dynamic_cast(load_outs[0])->node(); + for (size_t i = 0; i < l->noutputs(); ++i) + { + l->output(i)->divert_users(nn->output(i)); + } + remove(l); + auto addr = route_request_rhls(node->region(), load_outs.back()); + load_addrs.push_back(addr); } - auto load_outs = local_load_op::create(*index, states, *response); - auto nn = dynamic_cast(load_outs[0])->node(); - for (size_t i = 0; i < l->noutputs(); ++i) + std::vector store_operands; + for (auto s : ta.store_nodes) { - l->output(i)->divert_users(nn->output(i)); + auto index = gep_to_index(s->input(0)->origin()); + std::vector states; + for (size_t i = 2; i < s->ninputs(); ++i) + { + states.push_back(s->input(i)->origin()); + } + auto store_outs = local_store_op::create(*index, *s->input(1)->origin(), states); + auto nn = dynamic_cast(store_outs[0])->node(); + for (size_t i = 0; i < s->noutputs(); ++i) + { + s->output(i)->divert_users(nn->output(i)); + } + remove(s); + auto addr = route_request_rhls(node->region(), store_outs[store_outs.size() - 2]); + auto data = route_request_rhls(node->region(), store_outs.back()); + store_operands.push_back(addr); + store_operands.push_back(data); } - remove(l); - auto addr = route_request_rhls(node->region(), load_outs.back()); - load_addrs.push_back(addr); - } - std::vector store_operands; - for (auto s : ta.store_nodes) - { - auto index = gep_to_index(s->input(0)->origin()); - std::vector states; - for (size_t i = 2; i < s->ninputs(); ++i) + // TODO: ensure that loads/stores are either alloca or global, never both + // TODO: ensure that loads/stores have same width and alignment and geps can be merged - + // otherwise slice? create request + auto req_outs = local_mem_req_op::create(*mem_outs[1], load_addrs, store_operands); + + // remove alloca from memstate merge + // TODO: handle general case of other nodes getting state edge without a merge + JLM_ASSERT(node->output(1)->nusers() == 1); + auto merge_in = *node->output(1)->begin(); + auto merge_node = rvsdg::TryGetOwnerNode(*merge_in); + if (dynamic_cast(&merge_node->GetOperation())) { - states.push_back(s->input(i)->origin()); + // merge after alloca -> remove merge + JLM_ASSERT(merge_node->ninputs() == 2); + auto other_index = merge_in->index() ? 0 : 1; + merge_node->output(0)->divert_users(merge_node->input(other_index)->origin()); + jlm::rvsdg::remove(merge_node); } - auto store_outs = local_store_op::create(*index, *s->input(1)->origin(), states); - auto nn = dynamic_cast(store_outs[0])->node(); - for (size_t i = 0; i < s->noutputs(); ++i) + else { - s->output(i)->divert_users(nn->output(i)); + // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing + // it to the region + JLM_ASSERT(false); } - remove(s); - auto addr = route_request_rhls(node->region(), store_outs[store_outs.size() - 2]); - auto data = route_request_rhls(node->region(), store_outs.back()); - store_operands.push_back(addr); - store_operands.push_back(data); - } - // TODO: ensure that loads/stores are either alloca or global, never both - // TODO: ensure that loads/stores have same width and alignment and geps can be merged - - // otherwise slice? create request - auto req_outs = local_mem_req_op::create(*mem_outs[1], load_addrs, store_operands); - // remove alloca from memstate merge - // TODO: handle general case of other nodes getting state edge without a merge - JLM_ASSERT(node->output(1)->nusers() == 1); - auto merge_in = *node->output(1)->begin(); - auto merge_node = rvsdg::TryGetOwnerNode(*merge_in); - if (dynamic_cast(&merge_node->GetOperation())) - { - // merge after alloca -> remove merge - JLM_ASSERT(merge_node->ninputs() == 2); - auto other_index = merge_in->index() ? 0 : 1; - merge_node->output(0)->divert_users(merge_node->input(other_index)->origin()); - jlm::rvsdg::remove(merge_node); + // TODO: run dne to + // remove loads/stores + // remove geps + // remove alloca pointer users + // remove alloca } - else - { - // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing it - // to the region - JLM_ASSERT(false); - } - - // TODO: run dne to - // remove loads/stores - // remove geps - // remove alloca pointer users - // remove alloca } } } diff --git a/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp b/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp index 0c103df02..252c1be63 100644 --- a/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp +++ b/jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp @@ -163,89 +163,93 @@ instrument_ref( allocaFunctionType); } } - else if ( - auto loadOp = - dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto addr = node->input(0)->origin(); - JLM_ASSERT(rvsdg::is(addr->Type())); - size_t bitWidth = BaseHLS::JlmSize(&*loadOp->GetLoadedType()); - int log2Bytes = log2(bitWidth / 8); - auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); - - // Does this IF make sense now when the void_ptr doesn't have a type? - if (*addr->Type() != *void_ptr) - { - addr = jlm::llvm::bitcast_op::create(addr, void_ptr); - } - auto memstate = node->input(1)->origin(); - auto callOp = jlm::llvm::CallOperation::Create( - load_func, - loadFunctionType, - { addr, widthNode.output(0), ioState, memstate }); - // Divert the memory state of the load to the new memstate from the call operation - node->input(1)->divert_to(callOp[1]); - } - else if (auto ao = dynamic_cast(&(node->GetOperation()))) - { - // ensure that the size is one - JLM_ASSERT(node->ninputs() == 1); - auto constant_output = dynamic_cast(node->input(0)->origin()); - JLM_ASSERT(constant_output); - auto constant_operation = dynamic_cast( - &constant_output->node()->GetOperation()); - JLM_ASSERT(constant_operation); - JLM_ASSERT(constant_operation->Representation().to_uint() == 1); - jlm::rvsdg::output * addr = node->output(0); - // ensure that the alloca is an array type - JLM_ASSERT(jlm::rvsdg::is(addr->Type())); - auto at = dynamic_cast(&ao->value_type()); - JLM_ASSERT(at); - auto & sizeNode = - llvm::IntegerConstantOperation::Create(*region, 64, BaseHLS::JlmSize(at) / 8); - - // Does this IF make sense now when the void_ptr doesn't have a type? - if (*addr->Type() != *void_ptr) - { - addr = jlm::llvm::bitcast_op::create(addr, void_ptr); - } - std::vector old_users(node->output(1)->begin(), node->output(1)->end()); - auto memstate = node->output(1); - auto callOp = jlm::llvm::CallOperation::Create( - alloca_func, - allocaFunctionType, - { addr, sizeNode.output(0), ioState, memstate }); - for (auto ou : old_users) + if (auto loadOp = dynamic_cast( + &(simpleNode->GetOperation()))) { + auto addr = node->input(0)->origin(); + JLM_ASSERT(std::dynamic_pointer_cast(addr->Type())); + size_t bitWidth = BaseHLS::JlmSize(&*loadOp->GetLoadedType()); + int log2Bytes = log2(bitWidth / 8); + auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); + + // Does this IF make sense now when the void_ptr doesn't have a type? + if (*addr->Type() != *void_ptr) + { + addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + } + auto memstate = node->input(1)->origin(); + auto callOp = jlm::llvm::CallOperation::Create( + load_func, + loadFunctionType, + { addr, widthNode.output(0), ioState, memstate }); // Divert the memory state of the load to the new memstate from the call operation - ou->divert_to(callOp[1]); + node->input(1)->divert_to(callOp[1]); } - } - else if ( - auto so = - dynamic_cast(&(node->GetOperation()))) - { - auto addr = node->input(0)->origin(); - JLM_ASSERT(rvsdg::is(addr->Type())); - auto bitWidth = JlmSize(&so->GetStoredType()); - int log2Bytes = log2(bitWidth / 8); - auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); - - // Does this IF make sense now when the void_ptr doesn't have a type? - if (*addr->Type() != *void_ptr) + else if (auto ao = dynamic_cast(&(simpleNode->GetOperation()))) { - addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + // ensure that the size is one + JLM_ASSERT(node->ninputs() == 1); + auto & constant_node = + rvsdg::AssertGetOwnerNode(*node->input(0)->origin()); + auto constant_operation = + util::AssertedCast(&constant_node.GetOperation()); + JLM_ASSERT(constant_operation->Representation().to_uint() == 1); + jlm::rvsdg::output * addr = node->output(0); + // ensure that the alloca is an array type + auto pt = std::dynamic_pointer_cast(addr->Type()); + JLM_ASSERT(pt); + auto at = dynamic_cast(&ao->value_type()); + JLM_ASSERT(at); + auto & sizeNode = + llvm::IntegerConstantOperation::Create(*region, 64, BaseHLS::JlmSize(at) / 8); + + // Does this IF make sense now when the void_ptr doesn't have a type? + if (*addr->Type() != *void_ptr) + { + addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + } + std::vector old_users( + node->output(1)->begin(), + node->output(1)->end()); + auto memstate = node->output(1); + auto callOp = jlm::llvm::CallOperation::Create( + alloca_func, + allocaFunctionType, + { addr, sizeNode.output(0), ioState, memstate }); + for (auto ou : old_users) + { + // Divert the memory state of the load to the new memstate from the call operation + ou->divert_to(callOp[1]); + } } - auto memstate = node->output(0); - std::vector oldUsers(memstate->begin(), memstate->end()); - auto callOp = jlm::llvm::CallOperation::Create( - store_func, - storeFunctionType, - { addr, widthNode.output(0), ioState, memstate }); - // Divert the memory state after the store to the new memstate from the call operation - for (auto user : oldUsers) + else if ( + auto so = dynamic_cast( + &(simpleNode->GetOperation()))) { - user->divert_to(callOp[1]); + auto addr = node->input(0)->origin(); + JLM_ASSERT(std::dynamic_pointer_cast(addr->Type())); + auto bitWidth = JlmSize(&so->GetStoredType()); + int log2Bytes = log2(bitWidth / 8); + auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes); + + // Does this IF make sense now when the void_ptr doesn't have a type? + if (*addr->Type() != *void_ptr) + { + addr = jlm::llvm::bitcast_op::create(addr, void_ptr); + } + auto memstate = node->output(0); + std::vector oldUsers(memstate->begin(), memstate->end()); + auto callOp = jlm::llvm::CallOperation::Create( + store_func, + storeFunctionType, + { addr, widthNode.output(0), ioState, memstate }); + // Divert the memory state after the store to the new memstate from the call operation + for (auto user : oldUsers) + { + user->divert_to(callOp[1]); + } } } } diff --git a/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp b/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp index df624545e..119b45b8a 100644 --- a/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp +++ b/jlm/hls/backend/rvsdg2rhls/mem-queue.cpp @@ -507,7 +507,7 @@ jlm::hls::mem_queue(jlm::rvsdg::Region * region) // Check if there exists a memory state splitter if (state_arg->nusers() == 1) { - auto entryNode = rvsdg::TryGetOwnerNode(**state_arg->begin()); + auto entryNode = rvsdg::TryGetOwnerNode(**state_arg->begin()); if (jlm::rvsdg::is( entryNode->GetOperation())) { diff --git a/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp b/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp index 883463026..5918e6fb7 100644 --- a/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp +++ b/jlm/hls/backend/rvsdg2rhls/remove-redundant-buf.cpp @@ -58,9 +58,9 @@ remove_redundant_buf(rvsdg::Region * region) remove_redundant_buf(structnode->subregion(n)); } } - else if (dynamic_cast(node)) + else if (auto simplenode = dynamic_cast(node)) { - if (auto buf = dynamic_cast(&node->GetOperation())) + if (auto buf = dynamic_cast(&simplenode->GetOperation())) { if (std::dynamic_pointer_cast(buf->argument(0))) { diff --git a/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp b/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp index de4776c32..3e57cb997 100644 --- a/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp +++ b/jlm/hls/backend/rvsdg2rhls/rhls-dne.cpp @@ -126,7 +126,7 @@ remove_unused_loop_inputs(loop_node * ln) } bool -dead_spec_gamma(rvsdg::Node * dmux_node) +dead_spec_gamma(rvsdg::SimpleNode * dmux_node) { auto mux_op = dynamic_cast(&dmux_node->GetOperation()); JLM_ASSERT(mux_op); @@ -152,7 +152,7 @@ dead_spec_gamma(rvsdg::Node * dmux_node) } bool -dead_nonspec_gamma(rvsdg::Node * ndmux_node) +dead_nonspec_gamma(rvsdg::SimpleNode * ndmux_node) { auto mux_op = dynamic_cast(&ndmux_node->GetOperation()); JLM_ASSERT(mux_op); @@ -194,7 +194,7 @@ dead_nonspec_gamma(rvsdg::Node * ndmux_node) } bool -dead_loop(rvsdg::Node * ndmux_node) +dead_loop(rvsdg::SimpleNode * ndmux_node) { auto mux_op = dynamic_cast(&ndmux_node->GetOperation()); JLM_ASSERT(mux_op); @@ -282,19 +282,23 @@ dne(rvsdg::Region * sr) { if (!node->has_users()) { - if (dynamic_cast(&node->GetOperation())) + if (auto simpleNode = dynamic_cast(node)) { - // TODO: fix this once memory connections are explicit - continue; - } - else if (dynamic_cast(&node->GetOperation())) - { - continue; - } - else if (dynamic_cast(&node->GetOperation())) - { - // TODO: fix - this scenario has only stores and should just be optimized away completely - continue; + if (dynamic_cast(&simpleNode->GetOperation())) + { + // TODO: fix this once memory connections are explicit + continue; + } + else if (dynamic_cast(&simpleNode->GetOperation())) + { + continue; + } + else if (dynamic_cast(&simpleNode->GetOperation())) + { + // TODO: fix - this scenario has only stores and should just be optimized away + // completely + continue; + } } remove(node); changed = true; @@ -307,15 +311,18 @@ dne(rvsdg::Region * sr) changed |= remove_loop_passthrough(ln); changed |= dne(ln->subregion()); } - else if (auto mux = dynamic_cast(&node->GetOperation())) + else if (auto simpleNode = dynamic_cast(node)) { - if (mux->discarding) - { - changed |= dead_spec_gamma(node); - } - else + if (auto mux = dynamic_cast(&simpleNode->GetOperation())) { - changed |= dead_nonspec_gamma(node) || dead_loop(node); + if (mux->discarding) + { + changed |= dead_spec_gamma(simpleNode); + } + else + { + changed |= dead_nonspec_gamma(simpleNode) || dead_loop(simpleNode); + } } } } diff --git a/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp b/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp index f19f4c489..e7961509f 100644 --- a/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp +++ b/jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp @@ -150,30 +150,33 @@ inline_calls(rvsdg::Region * region) inline_calls(structnode->subregion(n)); } } - else if (dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto traced = jlm::hls::trace_call(node->input(0)); - auto so = dynamic_cast(traced); - if (!so) + if (dynamic_cast(&(simpleNode->GetOperation()))) { - if (auto graphImport = dynamic_cast(traced)) + auto traced = jlm::hls::trace_call(node->input(0)); + auto so = dynamic_cast(traced); + if (!so) { - if (graphImport->Name().rfind("decouple_", 0) == 0) + if (auto graphImport = dynamic_cast(traced)) { - // can't inline pseudo functions used for decoupling - continue; + if (graphImport->Name().rfind("decouple_", 0) == 0) + { + // can't inline pseudo functions used for decoupling + continue; + } + throw jlm::util::error("can not inline external function " + graphImport->Name()); } - throw jlm::util::error("can not inline external function " + graphImport->Name()); } + JLM_ASSERT(dynamic_cast(so->node())); + auto ln = dynamic_cast(traced)->node(); + llvm::inlineCall( + dynamic_cast(node), + dynamic_cast(ln)); + // restart for this region + inline_calls(region); + return; } - JLM_ASSERT(rvsdg::is(so->node())); - auto ln = dynamic_cast(traced)->node(); - llvm::inlineCall( - dynamic_cast(node), - dynamic_cast(ln)); - // restart for this region - inline_calls(region); - return; } } } @@ -192,57 +195,60 @@ convert_alloca(rvsdg::Region * region) convert_alloca(structnode->subregion(n)); } } - else if (auto po = dynamic_cast(&(node->GetOperation()))) + else if (auto simpleNode = dynamic_cast(node)) { - auto rr = ®ion->graph()->GetRootRegion(); - auto delta_name = jlm::util::strfmt("hls_alloca_", alloca_cnt++); - auto delta_type = llvm::PointerType::Create(); - std::cout << "alloca " << delta_name << ": " << po->value_type().debug_string() << "\n"; - auto db = llvm::delta::node::Create( - rr, - std::static_pointer_cast(po->ValueType()), - delta_name, - llvm::linkage::external_linkage, - "", - false); - // create zero constant of allocated type - jlm::rvsdg::output * cout; - if (auto bt = dynamic_cast(&po->value_type())) + if (auto po = dynamic_cast(&(simpleNode->GetOperation()))) { - cout = llvm::IntegerConstantOperation::Create( - *db->subregion(), - bt->Representation().nbits(), - 0) - .output(0); - } - else - { - cout = llvm::ConstantAggregateZeroOperation::Create(*db->subregion(), po->ValueType()); - } - auto delta = db->finalize(cout); - jlm::llvm::GraphExport::Create(*delta, delta_name); - auto delta_local = route_to_region_rvsdg(delta, region); - node->output(0)->divert_users(delta_local); - // TODO: check that the input to alloca is a bitconst 1 - // TODO: handle general case of other nodes getting state edge without a merge - JLM_ASSERT(node->output(1)->nusers() == 1); - auto mux_in = *node->output(1)->begin(); - auto mux_node = rvsdg::TryGetOwnerNode(*mux_in); - if (dynamic_cast(&mux_node->GetOperation())) - { - // merge after alloca -> remove merge - JLM_ASSERT(mux_node->ninputs() == 2); - auto other_index = mux_in->index() ? 0 : 1; - mux_node->output(0)->divert_users(mux_node->input(other_index)->origin()); - jlm::rvsdg::remove(mux_node); - } - else - { - // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing it - // to the region - JLM_ASSERT(false); + auto rr = ®ion->graph()->GetRootRegion(); + auto delta_name = jlm::util::strfmt("hls_alloca_", alloca_cnt++); + auto delta_type = llvm::PointerType::Create(); + std::cout << "alloca " << delta_name << ": " << po->value_type().debug_string() << "\n"; + auto db = llvm::delta::node::Create( + rr, + std::static_pointer_cast(po->ValueType()), + delta_name, + llvm::linkage::external_linkage, + "", + false); + // create zero constant of allocated type + jlm::rvsdg::output * cout; + if (auto bt = dynamic_cast(&po->value_type())) + { + cout = llvm::IntegerConstantOperation::Create( + *db->subregion(), + bt->Representation().nbits(), + 0) + .output(0); + } + else + { + cout = llvm::ConstantAggregateZeroOperation::Create(*db->subregion(), po->ValueType()); + } + auto delta = db->finalize(cout); + jlm::llvm::GraphExport::Create(*delta, delta_name); + auto delta_local = route_to_region_rvsdg(delta, region); + node->output(0)->divert_users(delta_local); + // TODO: check that the input to alloca is a bitconst 1 + // TODO: handle general case of other nodes getting state edge without a merge + JLM_ASSERT(node->output(1)->nusers() == 1); + auto mux_in = *node->output(1)->begin(); + auto mux_node = rvsdg::TryGetOwnerNode(*mux_in); + if (dynamic_cast(&mux_node->GetOperation())) + { + // merge after alloca -> remove merge + JLM_ASSERT(mux_node->ninputs() == 2); + auto other_index = mux_in->index() ? 0 : 1; + mux_node->output(0)->divert_users(mux_node->input(other_index)->origin()); + jlm::rvsdg::remove(mux_node); + } + else + { + // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing + // it to the region + JLM_ASSERT(false); + } + jlm::rvsdg::remove(node); } - jlm::rvsdg::remove(node); } } } diff --git a/jlm/hls/ir/hls.cpp b/jlm/hls/ir/hls.cpp index fb0b247b1..3310fc8ce 100644 --- a/jlm/hls/ir/hls.cpp +++ b/jlm/hls/ir/hls.cpp @@ -95,7 +95,7 @@ loop_node::AddLoopVar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer) return output; } -[[nodiscard]] const rvsdg::Operation & +[[nodiscard]] const rvsdg::StructuralOperation & loop_node::GetOperation() const noexcept { static const loop_op singleton; diff --git a/jlm/hls/ir/hls.hpp b/jlm/hls/ir/hls.hpp index 69db9cddd..a94f98e0c 100644 --- a/jlm/hls/ir/hls.hpp +++ b/jlm/hls/ir/hls.hpp @@ -755,7 +755,7 @@ class loop_node final : public rvsdg::StructuralNode jlm::rvsdg::node_output * _predicate_buffer; public: - [[nodiscard]] const rvsdg::Operation & + [[nodiscard]] const rvsdg::StructuralOperation & GetOperation() const noexcept override; static loop_node * diff --git a/jlm/hls/opt/IOBarrierRemoval.cpp b/jlm/hls/opt/IOBarrierRemoval.cpp index 41711d5af..a531d63ab 100644 --- a/jlm/hls/opt/IOBarrierRemoval.cpp +++ b/jlm/hls/opt/IOBarrierRemoval.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace jlm::hls @@ -33,11 +34,13 @@ IOBarrierRemoval::RemoveIOBarrierFromRegion(rvsdg::Region & region) RemoveIOBarrierFromRegion(*structuralNode->subregion(n)); } } - - // Render all IOBarrier nodes dead - if (rvsdg::is(&node)) + else if (auto simpleNode = dynamic_cast(&node)) { - node.output(0)->divert_users(node.input(0)->origin()); + // Render all IOBarrier nodes dead + if (rvsdg::is(simpleNode)) + { + node.output(0)->divert_users(node.input(0)->origin()); + } } } diff --git a/jlm/hls/opt/cne.cpp b/jlm/hls/opt/cne.cpp index 257055c0f..5c6c58764 100644 --- a/jlm/hls/opt/cne.cpp +++ b/jlm/hls/opt/cne.cpp @@ -272,20 +272,24 @@ congruent(jlm::rvsdg::output * o1, jlm::rvsdg::output * o2, vset & vs, cnectx & } } - if (jlm::rvsdg::is(n1) && jlm::rvsdg::is(n2) - && n1->GetOperation() == n2->GetOperation() && n1->ninputs() == n2->ninputs() - && o1->index() == o2->index()) + if (auto s1 = rvsdg::TryGetOwnerNode(*o1)) { - for (size_t n = 0; n < n1->ninputs(); n++) + if (auto s2 = rvsdg::TryGetOwnerNode(*o2)) { - auto origin1 = n1->input(n)->origin(); - auto origin2 = n2->input(n)->origin(); - if (!congruent(origin1, origin2, vs, ctx)) - return false; + if (s1->GetOperation() == s2->GetOperation() && s1->ninputs() == s2->ninputs() + && o1->index() == o2->index()) + { + for (size_t n = 0; n < s1->ninputs(); n++) + { + auto origin1 = s1->input(n)->origin(); + auto origin2 = s2->input(n)->origin(); + if (!congruent(origin1, origin2, vs, ctx)) + return false; + } + return true; + } } - return true; } - return false; } @@ -459,7 +463,9 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & other : node->region()->TopNodes()) { - if (&other != node && node->GetOperation() == other.GetOperation()) + auto otherSimpleNode = dynamic_cast(&other); + if (otherSimpleNode && &other != node + && node->GetOperation() == otherSimpleNode->GetOperation()) { ctx.mark(node, &other); break; @@ -473,8 +479,7 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & user : *origin) { - auto ni = dynamic_cast(user); - auto other = ni ? ni->node() : nullptr; + auto other = rvsdg::TryGetOwnerNode(*user); if (!other || other == node || other->GetOperation() != node->GetOperation() || other->ninputs() != node->ninputs()) continue; diff --git a/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp b/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp index a1c2e3b99..dd830c695 100644 --- a/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp +++ b/jlm/llvm/backend/RvsdgToIpGraphConverter.cpp @@ -162,7 +162,7 @@ RvsdgToIpGraphConverter::CreateInitialization(const delta::node & deltaNode) operands.push_back(Context_->GetVariable(node->input(n)->origin())); // convert node to tac - auto & op = *static_cast(&node->GetOperation()); + auto & op = util::AssertedCast(node)->GetOperation(); tacs.push_back(tac::create(op, operands)); Context_->InsertVariable(output, tacs.back()->result(0)); } diff --git a/jlm/llvm/ir/operators/operators.cpp b/jlm/llvm/ir/operators/operators.cpp index 370c8221d..c4bf408a2 100644 --- a/jlm/llvm/ir/operators/operators.cpp +++ b/jlm/llvm/ir/operators/operators.cpp @@ -408,7 +408,8 @@ ZExtOperation::reduce_operand(rvsdg::unop_reduction_path_t path, rvsdg::output * { if (path == rvsdg::unop_reduction_constant) { - auto c = static_cast(&producer(operand)->GetOperation()); + auto c = util::AssertedCast( + &util::AssertedCast(producer(operand))->GetOperation()); return create_bitconstant( rvsdg::TryGetOwnerNode(*operand)->region(), c->value().zext(ndstbits() - nsrcbits())); diff --git a/jlm/llvm/ir/operators/sext.cpp b/jlm/llvm/ir/operators/sext.cpp index 52707f72c..c3cc76bdd 100644 --- a/jlm/llvm/ir/operators/sext.cpp +++ b/jlm/llvm/ir/operators/sext.cpp @@ -118,7 +118,8 @@ sext_op::reduce_operand(rvsdg::unop_reduction_path_t path, rvsdg::output * opera { if (path == rvsdg::unop_reduction_constant) { - auto c = static_cast(&producer(operand)->GetOperation()); + auto c = util::AssertedCast( + &util::AssertedCast(producer(operand))->GetOperation()); return create_bitconstant(operand->region(), c->value().sext(ndstbits() - nsrcbits())); } diff --git a/jlm/llvm/opt/InvariantValueRedirection.cpp b/jlm/llvm/opt/InvariantValueRedirection.cpp index bfd39ca99..a4aee740e 100644 --- a/jlm/llvm/opt/InvariantValueRedirection.cpp +++ b/jlm/llvm/opt/InvariantValueRedirection.cpp @@ -86,11 +86,17 @@ InvariantValueRedirection::RedirectInRootRegion(rvsdg::Graph & rvsdg) // Nothing needs to be done. // Delta nodes are irrelevant for invariant value redirection. } - else if ( - is(node->GetOperation()) - || is(node->GetOperation())) + else if (auto simpleNode = dynamic_cast(node)) { - // Nothing needs to be done. + if (is(simpleNode->GetOperation()) + || is(simpleNode->GetOperation())) + { + // Nothing needs to be done. + } + else + { + JLM_UNREACHABLE("Unhandled node type."); + } } else { @@ -125,9 +131,12 @@ InvariantValueRedirection::RedirectInRegion(rvsdg::Region & region) RedirectInSubregions(*thetaNode); RedirectThetaOutputs(*thetaNode); } - else if (is(&node)) + else if (auto simpleNode = dynamic_cast(&node)) { - RedirectCallOutputs(*util::AssertedCast(&node)); + if (is(simpleNode)) + { + RedirectCallOutputs(*util::AssertedCast(&node)); + } } } } diff --git a/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp b/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp index 957aae803..cef90f8fb 100644 --- a/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp +++ b/jlm/llvm/opt/alias-analyses/TopDownModRefEliminator.cpp @@ -504,11 +504,18 @@ TopDownModRefEliminator::EliminateTopDownRootRegion(rvsdg::Region & region) { // Nothing needs to be done. } - else if ( - is(node->GetOperation()) - || is(node->GetOperation())) + else if (auto simpleNode = dynamic_cast(node)) { - // Nothing needs to be done. + if (is(simpleNode->GetOperation()) + || is(simpleNode->GetOperation())) + + { + // Nothing needs to be done. + } + else + { + JLM_UNREACHABLE("Unhandled node type!"); + } } else { diff --git a/jlm/llvm/opt/cne.cpp b/jlm/llvm/opt/cne.cpp index cf0ded29b..f6f5149c6 100644 --- a/jlm/llvm/opt/cne.cpp +++ b/jlm/llvm/opt/cne.cpp @@ -261,18 +261,23 @@ congruent(jlm::rvsdg::output * o1, jlm::rvsdg::output * o2, vset & vs, cnectx & } } - if (jlm::rvsdg::is(n1) && jlm::rvsdg::is(n2) - && n1->GetOperation() == n2->GetOperation() && n1->ninputs() == n2->ninputs() - && o1->index() == o2->index()) + if (auto simple1 = rvsdg::TryGetOwnerNode(*o1)) { - for (size_t n = 0; n < n1->ninputs(); n++) + if (auto simple2 = rvsdg::TryGetOwnerNode(*o2)) { - auto origin1 = n1->input(n)->origin(); - auto origin2 = n2->input(n)->origin(); - if (!congruent(origin1, origin2, vs, ctx)) - return false; + if (simple1->GetOperation() == simple2->GetOperation() + && simple1->ninputs() == simple2->ninputs() && o1->index() == o2->index()) + { + for (size_t n = 0; n < n1->ninputs(); n++) + { + auto origin1 = n1->input(n)->origin(); + auto origin2 = n2->input(n)->origin(); + if (!congruent(origin1, origin2, vs, ctx)) + return false; + } + return true; + } } - return true; } return false; @@ -423,7 +428,8 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & other : node->region()->TopNodes()) { - if (&other != node && node->GetOperation() == other.GetOperation()) + auto otherSimple = dynamic_cast(&other); + if (&other != node && otherSimple && node->GetOperation() == otherSimple->GetOperation()) { ctx.mark(node, &other); break; @@ -437,8 +443,7 @@ mark(const jlm::rvsdg::SimpleNode * node, cnectx & ctx) { for (const auto & user : *origin) { - auto ni = dynamic_cast(user); - auto other = ni ? ni->node() : nullptr; + auto other = rvsdg::TryGetOwnerNode(*user); if (!other || other == node || other->GetOperation() != node->GetOperation() || other->ninputs() != node->ninputs()) continue; diff --git a/jlm/llvm/opt/push.cpp b/jlm/llvm/opt/push.cpp index 9593ab7cc..b4eb4b705 100644 --- a/jlm/llvm/opt/push.cpp +++ b/jlm/llvm/opt/push.cpp @@ -322,7 +322,7 @@ is_movable_store(rvsdg::Node * node) } static void -pushout_store(rvsdg::Node * storenode) +pushout_store(rvsdg::SimpleNode * storenode) { JLM_ASSERT(dynamic_cast(storenode->region()->node())); JLM_ASSERT(jlm::rvsdg::is(storenode) && is_movable_store(storenode)); diff --git a/jlm/llvm/opt/reduction.cpp b/jlm/llvm/opt/reduction.cpp index 8c43b98c4..b6f202a83 100644 --- a/jlm/llvm/opt/reduction.cpp +++ b/jlm/llvm/opt/reduction.cpp @@ -84,9 +84,9 @@ NodeReduction::ReduceNodesInRegion(rvsdg::Region & region) { reductionPerformed |= ReduceStructuralNode(*structuralNode); } - else if (rvsdg::is(node)) + else if (const auto simpleNode = dynamic_cast(node)) { - reductionPerformed |= ReduceSimpleNode(*node); + reductionPerformed |= ReduceSimpleNode(*simpleNode); } else { @@ -144,7 +144,7 @@ NodeReduction::ReduceGammaNode(rvsdg::StructuralNode & gammaNode) } bool -NodeReduction::ReduceSimpleNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceSimpleNode(rvsdg::SimpleNode & simpleNode) { if (is(&simpleNode)) { @@ -169,7 +169,7 @@ NodeReduction::ReduceSimpleNode(rvsdg::Node & simpleNode) } bool -NodeReduction::ReduceLoadNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceLoadNode(rvsdg::SimpleNode & simpleNode) { JLM_ASSERT(is(&simpleNode)); @@ -177,7 +177,7 @@ NodeReduction::ReduceLoadNode(rvsdg::Node & simpleNode) } bool -NodeReduction::ReduceStoreNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceStoreNode(rvsdg::SimpleNode & simpleNode) { JLM_ASSERT(is(&simpleNode)); @@ -185,7 +185,7 @@ NodeReduction::ReduceStoreNode(rvsdg::Node & simpleNode) } bool -NodeReduction::ReduceBinaryNode(rvsdg::Node & simpleNode) +NodeReduction::ReduceBinaryNode(rvsdg::SimpleNode & simpleNode) { JLM_ASSERT(is(&simpleNode)); diff --git a/jlm/llvm/opt/reduction.hpp b/jlm/llvm/opt/reduction.hpp index 4a7e53be8..e79e142a1 100644 --- a/jlm/llvm/opt/reduction.hpp +++ b/jlm/llvm/opt/reduction.hpp @@ -6,6 +6,7 @@ #ifndef JLM_LLVM_OPT_REDUCTION_HPP #define JLM_LLVM_OPT_REDUCTION_HPP +#include #include #include @@ -73,16 +74,16 @@ class NodeReduction final : public rvsdg::Transformation ReduceGammaNode(rvsdg::StructuralNode & gammaNode); [[nodiscard]] static bool - ReduceSimpleNode(rvsdg::Node & simpleNode); + ReduceSimpleNode(rvsdg::SimpleNode & simpleNode); [[nodiscard]] static bool - ReduceLoadNode(rvsdg::Node & simpleNode); + ReduceLoadNode(rvsdg::SimpleNode & simpleNode); [[nodiscard]] static bool - ReduceStoreNode(rvsdg::Node & simpleNode); + ReduceStoreNode(rvsdg::SimpleNode & simpleNode); [[nodiscard]] static bool - ReduceBinaryNode(rvsdg::Node & simpleNode); + ReduceBinaryNode(rvsdg::SimpleNode & simpleNode); static std::optional> NormalizeLoadNode( diff --git a/jlm/llvm/opt/unroll.hpp b/jlm/llvm/opt/unroll.hpp index 7aab200b1..33bdf3701 100644 --- a/jlm/llvm/opt/unroll.hpp +++ b/jlm/llvm/opt/unroll.hpp @@ -51,8 +51,8 @@ class unrollinfo final private: inline unrollinfo( - rvsdg::Node * cmpnode, - rvsdg::Node * armnode, + rvsdg::SimpleNode * cmpnode, + rvsdg::SimpleNode * armnode, rvsdg::output * idv, rvsdg::output * step, rvsdg::output * end) @@ -108,7 +108,7 @@ class unrollinfo final std::unique_ptr niterations() const noexcept; - rvsdg::Node * + rvsdg::SimpleNode * cmpnode() const noexcept { return cmpnode_; @@ -117,10 +117,10 @@ class unrollinfo final [[nodiscard]] const rvsdg::SimpleOperation & cmpoperation() const noexcept { - return *static_cast(&cmpnode()->GetOperation()); + return cmpnode()->GetOperation(); } - inline rvsdg::Node * + inline rvsdg::SimpleNode * armnode() const noexcept { return armnode_; @@ -129,7 +129,7 @@ class unrollinfo final [[nodiscard]] const rvsdg::SimpleOperation & armoperation() const noexcept { - return *static_cast(&armnode()->GetOperation()); + return armnode()->GetOperation(); } inline rvsdg::output * @@ -206,10 +206,9 @@ class unrollinfo final inline bool is_known(jlm::rvsdg::output * output) const noexcept { - auto p = producer(output); + auto p = dynamic_cast(producer(output)); if (!p) return false; - auto op = dynamic_cast(&p->GetOperation()); return op && op->value().is_known(); } @@ -220,14 +219,14 @@ class unrollinfo final if (!is_known(output)) return nullptr; - auto p = producer(output); - return &static_cast(&p->GetOperation())->value(); + auto p = util::AssertedCast(producer(output)); + return &util::AssertedCast(&p->GetOperation())->value(); } rvsdg::output * end_; rvsdg::output * step_; - rvsdg::Node * cmpnode_; - rvsdg::Node * armnode_; + rvsdg::SimpleNode * cmpnode_; + rvsdg::SimpleNode * armnode_; rvsdg::output * idv_; }; diff --git a/jlm/rvsdg/NodeNormalization.hpp b/jlm/rvsdg/NodeNormalization.hpp index f4ce375e5..f3437e366 100644 --- a/jlm/rvsdg/NodeNormalization.hpp +++ b/jlm/rvsdg/NodeNormalization.hpp @@ -43,7 +43,7 @@ NormalizeSequence( template bool -ReduceNode(const NodeNormalization & nodeNormalization, Node & node) +ReduceNode(const NodeNormalization & nodeNormalization, SimpleNode & node) { auto operation = util::AssertedCast(&node.GetOperation()); auto operands = rvsdg::operands(&node); diff --git a/jlm/rvsdg/binary.cpp b/jlm/rvsdg/binary.cpp index 4edb79265..6f88f503d 100644 --- a/jlm/rvsdg/binary.cpp +++ b/jlm/rvsdg/binary.cpp @@ -81,10 +81,19 @@ FlattenAssociativeBinaryOperation( if (node == nullptr) return false; - auto flattenedBinaryOperation = - dynamic_cast(&node->GetOperation()); - return node->GetOperation() == operation - || (flattenedBinaryOperation && flattenedBinaryOperation->bin_operation() == operation); + auto simpleNode = dynamic_cast(node); + if (simpleNode) + { + auto flattenedBinaryOperation = + dynamic_cast(&simpleNode->GetOperation()); + return simpleNode->GetOperation() == operation + || (flattenedBinaryOperation + && flattenedBinaryOperation->bin_operation() == operation); + } + else + { + return false; + } }); if (operands == newOperands) @@ -220,12 +229,15 @@ FlattenedBinaryOperation::reduce( { for (auto & node : TopDownTraverser(region)) { - if (is(node)) + if (auto simpleNode = dynamic_cast(node)) { - const auto op = static_cast(&node->GetOperation()); - auto output = op->reduce(reduction, operands(node)); - node->output(0)->divert_users(output); - remove(node); + auto op = dynamic_cast(&simpleNode->GetOperation()); + if (op) + { + auto output = op->reduce(reduction, operands(node)); + node->output(0)->divert_users(output); + remove(node); + } } else if (auto structnode = dynamic_cast(node)) { diff --git a/jlm/rvsdg/bitstring/bitoperation-classes.cpp b/jlm/rvsdg/bitstring/bitoperation-classes.cpp index 297e00ec2..58fe82166 100644 --- a/jlm/rvsdg/bitstring/bitoperation-classes.cpp +++ b/jlm/rvsdg/bitstring/bitoperation-classes.cpp @@ -27,7 +27,7 @@ bitunary_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg { if (path == unop_reduction_constant) { - auto p = producer(arg); + auto p = static_cast(producer(arg)); auto & c = static_cast(p->GetOperation()); return create_bitconstant(p->region(), reduce_constant(c.value())); } @@ -57,8 +57,10 @@ bitbinary_op::reduce_operand_pair( { if (path == binop_reduction_constants) { - auto & c1 = static_cast(producer(arg1)->GetOperation()); - auto & c2 = static_cast(producer(arg2)->GetOperation()); + auto & c1 = static_cast( + static_cast(producer(arg1))->GetOperation()); + auto & c2 = static_cast( + static_cast(producer(arg2))->GetOperation()); return create_bitconstant(arg1->region(), reduce_constants(c1.value(), c2.value())); } @@ -73,12 +75,12 @@ bitcompare_op::can_reduce_operand_pair( const jlm::rvsdg::output * arg1, const jlm::rvsdg::output * arg2) const noexcept { - auto p = producer(arg1); + auto p = dynamic_cast(producer(arg1)); const bitconstant_op * c1_op = nullptr; if (p) c1_op = dynamic_cast(&p->GetOperation()); - p = producer(arg2); + p = dynamic_cast(producer(arg2)); const bitconstant_op * c2_op = nullptr; if (p) c2_op = dynamic_cast(&p->GetOperation()); diff --git a/jlm/rvsdg/bitstring/concat.cpp b/jlm/rvsdg/bitstring/concat.cpp index e6e0a7b96..5d58fd89e 100644 --- a/jlm/rvsdg/bitstring/concat.cpp +++ b/jlm/rvsdg/bitstring/concat.cpp @@ -58,8 +58,8 @@ bitconcat_op::can_reduce_operand_pair( const jlm::rvsdg::output * arg1, const jlm::rvsdg::output * arg2) const noexcept { - auto node1 = TryGetOwnerNode(*arg1); - auto node2 = TryGetOwnerNode(*arg2); + auto node1 = TryGetOwnerNode(*arg1); + auto node2 = TryGetOwnerNode(*arg2); if (!node1 || !node2) return binop_reduction_none; @@ -97,13 +97,13 @@ bitconcat_op::reduce_operand_pair( jlm::rvsdg::output * arg1, jlm::rvsdg::output * arg2) const { - auto node1 = static_cast(arg1)->node(); - auto node2 = static_cast(arg2)->node(); + auto & node1 = AssertGetOwnerNode(*arg1); + auto & node2 = AssertGetOwnerNode(*arg2); if (path == binop_reduction_constants) { - auto & arg1_constant = static_cast(node1->GetOperation()); - auto & arg2_constant = static_cast(node2->GetOperation()); + auto & arg1_constant = static_cast(node1.GetOperation()); + auto & arg2_constant = static_cast(node2.GetOperation()); bitvalue_repr bits(arg1_constant.value()); bits.Append(arg2_constant.value()); @@ -112,9 +112,9 @@ bitconcat_op::reduce_operand_pair( if (path == binop_reduction_merge) { - auto arg1_slice = static_cast(&node1->GetOperation()); - auto arg2_slice = static_cast(&node2->GetOperation()); - return jlm::rvsdg::bitslice(node1->input(0)->origin(), arg1_slice->low(), arg2_slice->high()); + auto arg1_slice = static_cast(&node1.GetOperation()); + auto arg2_slice = static_cast(&node2.GetOperation()); + return jlm::rvsdg::bitslice(node1.input(0)->origin(), arg1_slice->low(), arg2_slice->high()); /* FIXME: support sign bit */ } diff --git a/jlm/rvsdg/bitstring/slice.cpp b/jlm/rvsdg/bitstring/slice.cpp index 64e61943b..242cfd1fc 100644 --- a/jlm/rvsdg/bitstring/slice.cpp +++ b/jlm/rvsdg/bitstring/slice.cpp @@ -57,17 +57,17 @@ bitslice_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg return arg; } - auto node = static_cast(arg)->node(); + auto & node = AssertGetOwnerNode(*arg); if (path == unop_reduction_narrow) { - auto op = static_cast(node->GetOperation()); - return jlm::rvsdg::bitslice(node->input(0)->origin(), low() + op.low(), high() + op.low()); + auto op = static_cast(node.GetOperation()); + return jlm::rvsdg::bitslice(node.input(0)->origin(), low() + op.low(), high() + op.low()); } if (path == unop_reduction_constant) { - auto op = static_cast(node->GetOperation()); + auto op = static_cast(node.GetOperation()); std::string s(&op.value()[0] + low(), high() - low()); return create_bitconstant(arg->region(), s.c_str()); } @@ -76,9 +76,9 @@ bitslice_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg { size_t pos = 0, n; std::vector arguments; - for (n = 0; n < node->ninputs(); n++) + for (n = 0; n < node.ninputs(); n++) { - auto argument = node->input(n)->origin(); + auto argument = node.input(n)->origin(); size_t base = pos; size_t nbits = std::static_pointer_cast(argument->Type())->nbits(); pos = pos + nbits; diff --git a/jlm/rvsdg/control.cpp b/jlm/rvsdg/control.cpp index 1a61bc8a3..50e9a11b1 100644 --- a/jlm/rvsdg/control.cpp +++ b/jlm/rvsdg/control.cpp @@ -119,7 +119,8 @@ match_op::reduce_operand(unop_reduction_path_t path, jlm::rvsdg::output * arg) c { if (path == unop_reduction_constant) { - auto op = static_cast(producer(arg)->GetOperation()); + auto op = static_cast( + static_cast(producer(arg))->GetOperation()); return jlm::rvsdg::control_constant( arg->region(), nalternatives(), diff --git a/jlm/rvsdg/node.hpp b/jlm/rvsdg/node.hpp index ee4cc5e06..35a780718 100644 --- a/jlm/rvsdg/node.hpp +++ b/jlm/rvsdg/node.hpp @@ -588,9 +588,6 @@ class Node explicit Node(Region * region); - [[nodiscard]] virtual const Operation & - GetOperation() const noexcept = 0; - inline bool has_users() const noexcept { @@ -1030,16 +1027,6 @@ divert_users(Node * node, const std::vector & outputs) node->output(n)->divert_users(outputs[n]); } -template -static inline bool -is(const Node * node) noexcept -{ - if (!node) - return false; - - return is(node->GetOperation()); -} - Node * producer(const jlm::rvsdg::output * output) noexcept; diff --git a/jlm/rvsdg/simple-node.cpp b/jlm/rvsdg/simple-node.cpp index 86c21a212..d1f88eaeb 100644 --- a/jlm/rvsdg/simple-node.cpp +++ b/jlm/rvsdg/simple-node.cpp @@ -121,9 +121,9 @@ NormalizeSimpleOperationCommonNodeElimination( { auto isCongruent = [&](const Node & node) { - auto & nodeOperation = node.GetOperation(); - return nodeOperation == operation && operands == rvsdg::operands(&node) - && &nodeOperation != &operation; + auto simpleNode = dynamic_cast(&node); + return simpleNode && simpleNode->GetOperation() == operation + && operands == rvsdg::operands(&node) && &simpleNode->GetOperation() != &operation; }; if (operands.empty()) diff --git a/jlm/rvsdg/simple-node.hpp b/jlm/rvsdg/simple-node.hpp index 5dd3159d8..0c64c94a9 100644 --- a/jlm/rvsdg/simple-node.hpp +++ b/jlm/rvsdg/simple-node.hpp @@ -36,7 +36,7 @@ class SimpleNode final : public Node output(size_t index) const noexcept; [[nodiscard]] const SimpleOperation & - GetOperation() const noexcept override; + GetOperation() const noexcept; Node * copy(rvsdg::Region * region, const std::vector & operands) const override; @@ -217,6 +217,17 @@ CreateOpNode(Region & region, OperatorArguments... operatorArguments) {}); } +template +static inline bool +is(const Node * node) noexcept +{ + if (!node) + return false; + + auto simple_node = dynamic_cast(node); + return simple_node && dynamic_cast(&simple_node->GetOperation()); +} + } #endif diff --git a/jlm/rvsdg/structural-node.hpp b/jlm/rvsdg/structural-node.hpp index 2fddbeec1..f13e835fb 100644 --- a/jlm/rvsdg/structural-node.hpp +++ b/jlm/rvsdg/structural-node.hpp @@ -31,6 +31,9 @@ class StructuralNode : public Node std::string DebugString() const override; + [[nodiscard]] virtual const StructuralOperation & + GetOperation() const noexcept = 0; + inline size_t nsubregions() const noexcept { diff --git a/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp b/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp index 5507debb2..fdf670b9c 100644 --- a/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp +++ b/tests/jlm/hls/opt/InvariantLambdaMemoryStateRemovalTests.cpp @@ -63,7 +63,7 @@ TestEliminateSplitAndMergeNodes() assert(lambdaSubregion->nresults() == 1); assert(is(lambdaSubregion->result(0)->Type())); auto loadNode = - jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); + jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); assert(is(loadNode->GetOperation())); jlm::util::AssertedCast(loadNode->input(1)->origin()); @@ -136,14 +136,15 @@ TestInvariantMemoryState() assert(is(lambdaSubregion->result(0)->Type())); // Since there is more than one invariant memory state edge, the MemoryStateMerge node should // still exists - auto node = jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); + auto node = + jlm::rvsdg::TryGetOwnerNode(*lambdaSubregion->result(0)->origin()); assert(is(node->GetOperation())); assert(node->ninputs() == 2); // Need to pass a load node to reach the MemoryStateSplit node - node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); + node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); assert(is(node->GetOperation())); // Check that the MemoryStateSplit node is still present - node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); + node = jlm::rvsdg::TryGetOwnerNode(*node->input(1)->origin()); assert(is(node->GetOperation())); return 0; diff --git a/tests/jlm/llvm/ir/operators/LoadTests.cpp b/tests/jlm/llvm/ir/operators/LoadTests.cpp index 8b5c753b0..d6b4aa731 100644 --- a/tests/jlm/llvm/ir/operators/LoadTests.cpp +++ b/tests/jlm/llvm/ir/operators/LoadTests.cpp @@ -66,12 +66,14 @@ TestCopy() auto loadResults = LoadNonVolatileOperation::Create(address1, { memoryState1 }, valueType, 4); // Act - auto node = jlm::rvsdg::TryGetOwnerNode(*loadResults[0]); + auto node = jlm::rvsdg::TryGetOwnerNode(*loadResults[0]); assert(is(node)); auto copiedNode = node->copy(&graph.GetRootRegion(), { address2, memoryState2 }); // Assert - assert(node->GetOperation() == copiedNode->GetOperation()); + assert( + node->GetOperation() + == jlm::util::AssertedCast(copiedNode)->GetOperation()); return 0; } @@ -638,7 +640,8 @@ NodeCopy() auto copiedNode = loadNode.copy(&graph.GetRootRegion(), { &address2, &iOState2, &memoryState2 }); // Assert - auto copiedOperation = dynamic_cast(&copiedNode->GetOperation()); + auto copiedOperation = dynamic_cast( + &jlm::util::AssertedCast(copiedNode)->GetOperation()); assert(copiedOperation != nullptr); assert(LoadOperation::AddressInput(*copiedNode).origin() == &address2); assert(LoadVolatileOperation::IOStateInput(*copiedNode).origin() == &iOState2); diff --git a/tests/jlm/llvm/ir/operators/StoreTests.cpp b/tests/jlm/llvm/ir/operators/StoreTests.cpp index a8c392c25..fac498ee0 100644 --- a/tests/jlm/llvm/ir/operators/StoreTests.cpp +++ b/tests/jlm/llvm/ir/operators/StoreTests.cpp @@ -194,11 +194,13 @@ TestCopy() auto storeResults = StoreNonVolatileOperation::Create(address1, value1, { memoryState1 }, 4); // Act - auto node = jlm::rvsdg::TryGetOwnerNode(*storeResults[0]); + auto node = jlm::rvsdg::TryGetOwnerNode(*storeResults[0]); auto copiedNode = node->copy(&graph.GetRootRegion(), { address2, value2, memoryState2 }); // Assert - assert(node->GetOperation() == copiedNode->GetOperation()); + assert( + node->GetOperation() + == jlm::util::AssertedCast(copiedNode)->GetOperation()); return 0; } @@ -240,9 +242,9 @@ TestStoreMuxNormalization() auto muxNode = jlm::rvsdg::TryGetOwnerNode(*ex.origin()); assert(is(muxNode)); assert(muxNode->ninputs() == 3); - auto n0 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(0)->origin()); - auto n1 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(1)->origin()); - auto n2 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(2)->origin()); + auto n0 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(0)->origin()); + auto n1 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(1)->origin()); + auto n2 = jlm::rvsdg::TryGetOwnerNode(*muxNode->input(2)->origin()); assert(jlm::rvsdg::is(n0->GetOperation())); assert(jlm::rvsdg::is(n1->GetOperation())); assert(jlm::rvsdg::is(n2->GetOperation())); diff --git a/tests/jlm/llvm/ir/operators/TestCall.cpp b/tests/jlm/llvm/ir/operators/TestCall.cpp index d7433fd36..11e0841c0 100644 --- a/tests/jlm/llvm/ir/operators/TestCall.cpp +++ b/tests/jlm/llvm/ir/operators/TestCall.cpp @@ -41,12 +41,14 @@ TestCopy() CallOperation::Create(function1, functionType, { value1, iOState1, memoryState1 }); // Act - auto node = jlm::rvsdg::TryGetOwnerNode(*callResults[0]); + auto node = jlm::rvsdg::TryGetOwnerNode(*callResults[0]); auto copiedNode = node->copy(&rvsdg.GetRootRegion(), { function2, value2, iOState2, memoryState2 }); // Assert - assert(node->GetOperation() == copiedNode->GetOperation()); + assert( + node->GetOperation() + == jlm::util::AssertedCast(copiedNode)->GetOperation()); } static void diff --git a/tests/jlm/llvm/ir/operators/test-sext.cpp b/tests/jlm/llvm/ir/operators/test-sext.cpp index 994ed3d3a..772f72ec9 100644 --- a/tests/jlm/llvm/ir/operators/test-sext.cpp +++ b/tests/jlm/llvm/ir/operators/test-sext.cpp @@ -36,7 +36,7 @@ test_bitunary_reduction() // Act ReduceNode( NormalizeUnaryOperation, - *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); + *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); @@ -68,7 +68,7 @@ test_bitbinary_reduction() // Act ReduceNode( NormalizeUnaryOperation, - *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); + *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); @@ -97,7 +97,9 @@ test_inverse_reduction() view(graph, stdout); // Act - ReduceNode(NormalizeUnaryOperation, *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); + ReduceNode( + NormalizeUnaryOperation, + *jlm::rvsdg::TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); diff --git a/tests/jlm/llvm/opt/IfConversionTests.cpp b/tests/jlm/llvm/opt/IfConversionTests.cpp index cf5a87240..c28b77b08 100644 --- a/tests/jlm/llvm/opt/IfConversionTests.cpp +++ b/tests/jlm/llvm/opt/IfConversionTests.cpp @@ -127,10 +127,11 @@ EmptyGammaWithTwoSubregionsAndMatch() assert(selectNode->input(2)->origin() == falseValue); const auto eqNode = - jlm::rvsdg::TryGetOwnerNode(*selectNode->input(0)->origin()); + jlm::rvsdg::TryGetOwnerNode(*selectNode->input(0)->origin()); assert(eqNode && is(eqNode)); - auto constantNode = jlm::rvsdg::TryGetOwnerNode(*eqNode->input(0)->origin()); + auto constantNode = + jlm::rvsdg::TryGetOwnerNode(*eqNode->input(0)->origin()); if (constantNode) { assert(eqNode->input(1)->origin() == conditionValue); @@ -142,7 +143,7 @@ EmptyGammaWithTwoSubregionsAndMatch() else { assert(eqNode->input(0)->origin() == conditionValue); - constantNode = jlm::rvsdg::TryGetOwnerNode(*eqNode->input(1)->origin()); + constantNode = jlm::rvsdg::TryGetOwnerNode(*eqNode->input(1)->origin()); auto constantOperation = dynamic_cast(&constantNode->GetOperation()); assert(constantOperation); diff --git a/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp b/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp index eba39245c..5cd440237 100644 --- a/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp +++ b/tests/jlm/mlir/TestIntegerOperationsJlmToMlirToJlm.cpp @@ -97,7 +97,9 @@ TestIntegerBinaryOperation() bool foundBinaryOp = false; for (auto & node : region->Nodes()) { - auto convertedBinaryOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedBinaryOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedBinaryOp) { assert(convertedBinaryOp->nresults() == 1); @@ -225,7 +227,9 @@ TestIntegerComparisonOperation(const IntegerComparisonOpTest & tes bool foundCompOp = false; for (auto & node : region->Nodes()) { - auto convertedCompOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedCompOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedCompOp) { assert(convertedCompOp->nresults() == 1); diff --git a/tests/jlm/mlir/TestJlmToMlirToJlm.cpp b/tests/jlm/mlir/TestJlmToMlirToJlm.cpp index 8b13429b3..b79840e38 100644 --- a/tests/jlm/mlir/TestJlmToMlirToJlm.cpp +++ b/tests/jlm/mlir/TestJlmToMlirToJlm.cpp @@ -61,8 +61,9 @@ TestUndef() assert(region->nnodes() == 1); // Get the undef op - auto convertedUndef = - dynamic_cast(®ion->Nodes().begin()->GetOperation()); + auto convertedUndef = dynamic_cast( + &jlm::util::AssertedCast(&*region->Nodes().begin()) + ->GetOperation()); assert(convertedUndef != nullptr); @@ -139,7 +140,9 @@ TestAlloca() bool foundAlloca = false; for (auto & node : region->Nodes()) { - if (auto allocaOp = dynamic_cast(&node.GetOperation())) + auto simpleNode = dynamic_cast(&node); + if (auto allocaOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr)) { assert(allocaOp->alignment() == 4); @@ -241,14 +244,14 @@ TestLoad() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is( - convertedLambda->subregion()->Nodes().begin()->GetOperation())); + assert(is(jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation())); auto convertedLoad = convertedLambda->subregion()->Nodes().begin().ptr(); - auto loadOperation = - dynamic_cast(&convertedLoad->GetOperation()); + auto loadOperation = dynamic_cast( + &jlm::util::AssertedCast(convertedLoad)->GetOperation()); assert(loadOperation->GetAlignment() == 4); assert(loadOperation->NumMemoryStates() == 1); @@ -338,14 +341,14 @@ TestStore() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is( - convertedLambda->subregion()->Nodes().begin()->GetOperation())); + assert(is(jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation())); auto convertedStore = convertedLambda->subregion()->Nodes().begin().ptr(); - auto convertedStoreOperation = - dynamic_cast(&convertedStore->GetOperation()); + auto convertedStoreOperation = dynamic_cast( + &jlm::util::AssertedCast(convertedStore)->GetOperation()); assert(convertedStoreOperation->GetAlignment() == 4); assert(convertedStoreOperation->NumMemoryStates() == 1); @@ -425,12 +428,13 @@ TestSext() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); - auto convertedSext = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + auto convertedSext = + dynamic_cast(&jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedSext); assert(convertedSext->ndstbits() == 64); assert(convertedSext->nsrcbits() == 32); @@ -501,9 +505,11 @@ TestSitofp() auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); auto convertedSitofp = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + &jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedSitofp); assert(jlm::rvsdg::is(*convertedSitofp->argument(0).get())); assert(jlm::rvsdg::is(*convertedSitofp->result(0).get())); @@ -560,9 +566,11 @@ TestConstantFP() auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); - auto convertedConst = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + auto convertedConst = + dynamic_cast(&jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedConst); assert(jlm::rvsdg::is(*convertedConst->result(0).get())); assert(convertedConst->constant().isExactlyValue(2.0)); @@ -643,7 +651,8 @@ TestFpBinary() assert(convertedLambda->subregion()->nnodes() == 1); auto node = convertedLambda->subregion()->Nodes().begin().ptr(); - auto convertedFpbin = jlm::util::AssertedCast(&node->GetOperation()); + auto convertedFpbin = jlm::util::AssertedCast( + &jlm::util::AssertedCast(&*node)->GetOperation()); assert(convertedFpbin->fpop() == binOp); assert(convertedFpbin->nresults() == 1); assert(convertedFpbin->narguments() == 2); @@ -732,8 +741,9 @@ TestGetElementPtr() assert(convertedLambda->subregion()->nnodes() == 1); auto op = convertedLambda->subregion()->Nodes().begin(); - assert(is(op->GetOperation())); - auto convertedGep = dynamic_cast(&op->GetOperation()); + auto convertedGep = dynamic_cast( + &jlm::util::AssertedCast(&*op)->GetOperation()); + assert(convertedGep); assert(is(convertedGep->GetPointeeType())); assert(is(convertedGep->result(0))); @@ -847,7 +857,8 @@ TestDelta() assert(convertedDelta->Section() == "section"); auto op = convertedDelta->subregion()->Nodes().begin(); - assert(is(op->GetOperation())); + assert(is( + jlm::util::AssertedCast(&*op)->GetOperation())); } } } @@ -912,7 +923,9 @@ TestConstantDataArray() bool foundConstantDataArray = false; for (auto & node : region->Nodes()) { - if (auto constantDataArray = dynamic_cast(&node.GetOperation())) + auto simpleNode = dynamic_cast(&node); + if (auto constantDataArray = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr)) { foundConstantDataArray = true; assert(constantDataArray->nresults() == 1); @@ -977,7 +990,8 @@ TestConstantAggregateZero() assert(region->nnodes() == 1); auto const convertedConstantAggregateZero = jlm::util::AssertedCast( - ®ion->Nodes().begin().ptr()->GetOperation()); + &jlm::util::AssertedCast(&*region->Nodes().begin().ptr()) + ->GetOperation()); assert(convertedConstantAggregateZero->nresults() == 1); assert(convertedConstantAggregateZero->narguments() == 0); auto resultType = convertedConstantAggregateZero->result(0); @@ -1044,7 +1058,9 @@ TestVarArgList() bool foundVarArgOp = false; for (auto & node : region->Nodes()) { - auto convertedVarArgOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedVarArgOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedVarArgOp) { assert(convertedVarArgOp->nresults() == 1); @@ -1119,7 +1135,9 @@ TestFNeg() bool foundFNegOp = false; for (auto & node : region->Nodes()) { - auto convertedFNegOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedFNegOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedFNegOp) { assert(convertedFNegOp->nresults() == 1); @@ -1197,7 +1215,9 @@ TestFPExt() bool foundFPExtOp = false; for (auto & node : region->Nodes()) { - auto convertedFPExtOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedFPExtOp = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedFPExtOp) { assert(convertedFPExtOp->nresults() == 1); @@ -1274,7 +1294,9 @@ TestTrunc() bool foundTruncOp = false; for (auto & node : region->Nodes()) { - auto convertedTruncOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedTruncOp = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedTruncOp) { assert(convertedTruncOp->nresults() == 1); @@ -1365,12 +1387,13 @@ TestFree() assert(region->nnodes() == 1); auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); assert(convertedLambda->subregion()->nnodes() == 1); - assert(is(convertedLambda->subregion()->Nodes().begin()->GetOperation())); - auto convertedFree = dynamic_cast( - &convertedLambda->subregion()->Nodes().begin()->GetOperation()); + auto convertedFree = + dynamic_cast(&jlm::util::AssertedCast( + &*convertedLambda->subregion()->Nodes().begin()) + ->GetOperation()); + assert(convertedFree); assert(convertedFree->narguments() == 3); assert(convertedFree->nresults() == 2); @@ -1626,9 +1649,11 @@ TestIOBarrier() // Find the IOBarrier in the lambda subregion bool foundIOBarrier = false; - for (auto & lambdaNode : lambdaOperation->subregion()->Nodes()) + for (auto & node : lambdaOperation->subregion()->Nodes()) { - auto ioBarrierOp = dynamic_cast(&lambdaNode.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto ioBarrierOp = dynamic_cast( + simpleNode ? &simpleNode->GetOperation() : nullptr); if (ioBarrierOp) { foundIOBarrier = true; @@ -1712,7 +1737,9 @@ TestMalloc() bool foundMallocOp = false; for (auto & node : region->Nodes()) { - auto convertedMallocOp = dynamic_cast(&node.GetOperation()); + auto simpleNode = dynamic_cast(&node); + auto convertedMallocOp = + dynamic_cast(simpleNode ? &simpleNode->GetOperation() : nullptr); if (convertedMallocOp) { assert(convertedMallocOp->nresults() == 2); diff --git a/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp b/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp index 7d6cd4631..a18392d63 100644 --- a/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp +++ b/tests/jlm/mlir/frontend/TestMlirToJlmConverter.cpp @@ -267,17 +267,15 @@ TestDivOperation() // Get the lambda block auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); + assert(is(convertedLambda->GetOperation())); // 2 Constants + 1 DivUIOp assert(convertedLambda->subregion()->nnodes() == 3); // Traverse the rvsgd graph upwards to check connections - jlm::rvsdg::node_output * lambdaResultOriginNodeOuput; - assert( - lambdaResultOriginNodeOuput = dynamic_cast( - convertedLambda->subregion()->result(0)->origin())); - Node * lambdaResultOriginNode = lambdaResultOriginNodeOuput->node(); + auto lambdaResultOriginNode = jlm::rvsdg::TryGetOwnerNode( + *convertedLambda->subregion()->result(0)->origin()); + assert(lambdaResultOriginNode); assert(is(lambdaResultOriginNode->GetOperation())); assert(lambdaResultOriginNode->ninputs() == 2); @@ -441,7 +439,7 @@ TestCompZeroExt() // Get the lambda block auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); + assert(is(convertedLambda->GetOperation())); // 2 Constants + AddOp + CompOp + ZeroExtOp assert(convertedLambda->subregion()->nnodes() == 5); @@ -646,7 +644,7 @@ TestMatchOp() // Get the lambda block auto convertedLambda = jlm::util::AssertedCast(region->Nodes().begin().ptr()); - assert(is(convertedLambda)); + assert(is(convertedLambda->GetOperation())); auto lambdaRegion = convertedLambda->subregion(); diff --git a/tests/jlm/rvsdg/SimpleOperationTests.cpp b/tests/jlm/rvsdg/SimpleOperationTests.cpp index 7657e1e4b..f006fa851 100644 --- a/tests/jlm/rvsdg/SimpleOperationTests.cpp +++ b/tests/jlm/rvsdg/SimpleOperationTests.cpp @@ -51,10 +51,18 @@ NormalizeSimpleOperationCne_NodesWithoutOperands() }; // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryValueNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryValueNode2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryStateNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryStateNode2.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryValueNode1.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryValueNode2.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryStateNode1.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryStateNode2.origin())); graph.PruneNodes(); view(graph, stdout); @@ -106,10 +114,10 @@ NormalizeSimpleOperationCne_NodesWithOperands() }; // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exValueNode2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*exStateNode2.origin())); graph.PruneNodes(); view(graph, stdout); @@ -163,10 +171,18 @@ NormalizeSimpleOperationCne_Failure() }; // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryValueNode.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exNullaryStateNode.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exUnaryValueNode.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*exUnaryStateNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryValueNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exNullaryStateNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exUnaryValueNode.origin())); + ReduceNode( + NormalizeCne, + *TryGetOwnerNode(*exUnaryStateNode.origin())); graph.PruneNodes(); view(graph, stdout); diff --git a/tests/jlm/rvsdg/bitstring/bitstring.cpp b/tests/jlm/rvsdg/bitstring/bitstring.cpp index 06f897ae0..cfdea0647 100644 --- a/tests/jlm/rvsdg/bitstring/bitstring.cpp +++ b/tests/jlm/rvsdg/bitstring/bitstring.cpp @@ -40,8 +40,8 @@ types_bitstring_arithmetic_test_bitand() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitand_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, +1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitand_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, +1)); return 0; } @@ -86,11 +86,11 @@ types_bitstring_arithmetic_test_bitashr() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitashr_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 4)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 0)); - assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == int_constant_op(32, -4)); - assert(TryGetOwnerNode(*ex4.origin())->GetOperation() == int_constant_op(32, -1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitashr_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 4)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 0)); + assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == int_constant_op(32, -4)); + assert(TryGetOwnerNode(*ex4.origin())->GetOperation() == int_constant_op(32, -1)); return 0; } @@ -118,7 +118,7 @@ types_bitstring_arithmetic_test_bitdifference() view(&graph.GetRootRegion(), stdout); // Act - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsub_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsub_op(32)); return 0; } @@ -152,9 +152,9 @@ types_bitstring_arithmetic_test_bitnegate() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitneg_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -3)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitneg_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -3)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); return 0; } @@ -188,9 +188,9 @@ types_bitstring_arithmetic_test_bitnot() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitnot_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -4)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitnot_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -4)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == int_constant_op(32, 3)); return 0; } @@ -224,8 +224,8 @@ types_bitstring_arithmetic_test_bitor() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitor_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 7)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitor_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 7)); return 0; } @@ -259,8 +259,8 @@ types_bitstring_arithmetic_test_bitproduct() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitmul_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 15)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitmul_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 15)); return 0; } @@ -288,7 +288,7 @@ types_bitstring_arithmetic_test_bitshiproduct() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmulh_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmulh_op(32)); return 0; } @@ -326,9 +326,9 @@ types_bitstring_arithmetic_test_bitshl() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshl_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 64)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshl_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 64)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); return 0; } @@ -366,9 +366,9 @@ types_bitstring_arithmetic_test_bitshr() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshr_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 4)); - assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitshr_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == uint_constant_op(32, 4)); + assert(TryGetOwnerNode(*ex2.origin())->GetOperation() == uint_constant_op(32, 0)); return 0; } @@ -402,8 +402,8 @@ types_bitstring_arithmetic_test_bitsmod() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmod_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsmod_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -1)); return 0; } @@ -437,8 +437,8 @@ types_bitstring_arithmetic_test_bitsquotient() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsdiv_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -2)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsdiv_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, -2)); return 0; } @@ -472,8 +472,8 @@ types_bitstring_arithmetic_test_bitsum() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitadd_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 8)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitadd_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 8)); return 0; } @@ -501,7 +501,7 @@ types_bitstring_arithmetic_test_bituhiproduct() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumulh_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumulh_op(32)); return 0; } @@ -535,8 +535,8 @@ types_bitstring_arithmetic_test_bitumod() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumod_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 1)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitumod_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 1)); return 0; } @@ -570,8 +570,8 @@ types_bitstring_arithmetic_test_bituquotient() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitudiv_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 2)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitudiv_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 2)); return 0; } @@ -605,8 +605,8 @@ types_bitstring_arithmetic_test_bitxor() view(&graph.GetRootRegion(), stdout); // Arrange - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitxor_op(32)); - assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 6)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitxor_op(32)); + assert(TryGetOwnerNode(*ex1.origin())->GetOperation() == int_constant_op(32, 6)); return 0; } @@ -614,7 +614,7 @@ types_bitstring_arithmetic_test_bitxor() static inline void expect_static_true(jlm::rvsdg::output * port) { - auto node = jlm::rvsdg::TryGetOwnerNode(*port); + auto node = jlm::rvsdg::TryGetOwnerNode(*port); auto op = dynamic_cast(&node->GetOperation()); assert(op && op->value().nbits() == 1 && op->value().str() == "1"); } @@ -622,7 +622,7 @@ expect_static_true(jlm::rvsdg::output * port) static inline void expect_static_false(jlm::rvsdg::output * port) { - auto node = jlm::rvsdg::TryGetOwnerNode(*port); + auto node = jlm::rvsdg::TryGetOwnerNode(*port); auto op = dynamic_cast(&node->GetOperation()); assert(op && op->value().nbits() == 1 && op->value().str() == "0"); } @@ -662,10 +662,10 @@ types_bitstring_comparison_test_bitequal() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == biteq_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == biteq_op(32)); expect_static_true(ex1.origin()); expect_static_false(ex2.origin()); - assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == biteq_op(32)); + assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == biteq_op(32)); return 0; } @@ -705,10 +705,10 @@ types_bitstring_comparison_test_bitnotequal() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitne_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitne_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); - assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == bitne_op(32)); + assert(TryGetOwnerNode(*ex3.origin())->GetOperation() == bitne_op(32)); return 0; } @@ -753,7 +753,7 @@ types_bitstring_comparison_test_bitsgreater() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsgt_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsgt_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -804,7 +804,7 @@ types_bitstring_comparison_test_bitsgreatereq() view(&graph.GetRootRegion(), stdout); // Arrange - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsge_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsge_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_true(ex3.origin()); @@ -854,7 +854,7 @@ types_bitstring_comparison_test_bitsless() view(&graph.GetRootRegion(), stdout); // Arrange - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitslt_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitslt_op(32)); expect_static_true(ex1.origin()); expect_static_false(ex2.origin()); expect_static_false(ex3.origin()); @@ -906,7 +906,7 @@ types_bitstring_comparison_test_bitslesseq() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsle_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitsle_op(32)); expect_static_true(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -955,7 +955,7 @@ types_bitstring_comparison_test_bitugreater() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitugt_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitugt_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -1006,7 +1006,7 @@ types_bitstring_comparison_test_bitugreatereq() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bituge_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bituge_op(32)); expect_static_false(ex1.origin()); expect_static_true(ex2.origin()); expect_static_true(ex3.origin()); @@ -1056,7 +1056,7 @@ types_bitstring_comparison_test_bituless() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitult_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitult_op(32)); expect_static_true(ex1.origin()); expect_static_false(ex2.origin()); expect_static_false(ex3.origin()); @@ -1108,7 +1108,7 @@ types_bitstring_comparison_test_bitulesseq() view(&graph.GetRootRegion(), stdout); // Assert - assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitule_op(32)); + assert(TryGetOwnerNode(*ex0.origin())->GetOperation() == bitule_op(32)); expect_static_true(ex1.origin()); expect_static_true(ex2.origin()); expect_static_false(ex3.origin()); @@ -1179,19 +1179,19 @@ types_bitstring_test_constant() assert(b1.GetOperation() == uint_constant_op(8, 204)); assert(b1.GetOperation() == int_constant_op(8, -52)); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex3.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex4.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex3.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex4.origin())); assert(ex1.origin() == ex2.origin()); assert(ex1.origin() == ex3.origin()); - const auto node1 = TryGetOwnerNode(*ex1.origin()); + const auto node1 = TryGetOwnerNode(*ex1.origin()); assert(node1->GetOperation() == uint_constant_op(8, 204)); assert(node1->GetOperation() == int_constant_op(8, -52)); - const auto node4 = TryGetOwnerNode(*ex4.origin()); + const auto node4 = TryGetOwnerNode(*ex4.origin()); assert(node4->GetOperation() == uint_constant_op(9, 204)); assert(node4->GetOperation() == int_constant_op(9, 204)); @@ -1230,14 +1230,14 @@ types_bitstring_test_normalize() // Act ReduceNode(FlattenAssociativeBinaryOperation, sum1); - auto & flattenedBinaryNode = *TryGetOwnerNode(*ex.origin()); + auto & flattenedBinaryNode = *TryGetOwnerNode(*ex.origin()); ReduceNode(NormalizeFlattenedBinaryOperation, flattenedBinaryNode); graph.PruneNodes(); view(&graph.GetRootRegion(), stdout); // Assert - auto node = TryGetOwnerNode(*ex.origin()); + auto node = TryGetOwnerNode(*ex.origin()); assert(node->GetOperation() == bitadd_op(32)); assert(node->ninputs() == 2); auto op1 = node->input(0)->origin(); @@ -1249,7 +1249,7 @@ types_bitstring_test_normalize() op2 = tmp; } /* FIXME: the graph traversers are currently broken, that is why it won't normalize */ - assert(TryGetOwnerNode(*op1)->GetOperation() == int_constant_op(32, 3 + 4)); + assert(TryGetOwnerNode(*op1)->GetOperation() == int_constant_op(32, 3 + 4)); assert(op2 == imp); view(&graph.GetRootRegion(), stdout); @@ -1260,7 +1260,7 @@ types_bitstring_test_normalize() static void assert_constant(jlm::rvsdg::output * bitstr, size_t nbits, const char bits[]) { - auto node = jlm::rvsdg::TryGetOwnerNode(*bitstr); + auto node = jlm::rvsdg::TryGetOwnerNode(*bitstr); auto op = dynamic_cast(node->GetOperation()); assert(op.value() == jlm::rvsdg::bitvalue_repr(std::string(bits, nbits).c_str())); } @@ -1402,7 +1402,7 @@ ConcatOfSliceReduction() view(&graph.GetRootRegion(), stdout); // Assert - const auto sliceNode = TryGetOwnerNode(*ex.origin()); + const auto sliceNode = TryGetOwnerNode(*ex.origin()); assert(sliceNode->GetOperation() == bitslice_op(bit16Type, 0, 16)); assert(sliceNode->input(0)->origin() == x); @@ -1435,7 +1435,7 @@ SliceOfConstant() view(graph, stdout); // Assert - const auto node = TryGetOwnerNode(*ex.origin()); + const auto node = TryGetOwnerNode(*ex.origin()); auto & operation = dynamic_cast(node->GetOperation()); assert(operation.value() == bitvalue_repr("1101")); @@ -1468,7 +1468,7 @@ SliceOfSlice() view(graph, stdout); // Assert - const auto node = TryGetOwnerNode(*ex.origin()); + const auto node = TryGetOwnerNode(*ex.origin()); const auto operation = dynamic_cast(&node->GetOperation()); assert(operation->low() == 3 && operation->high() == 5); @@ -1527,11 +1527,11 @@ SliceOfConcat() // Act ReduceNode(NormalizeUnaryOperation, sliceNode); - auto concatNode = TryGetOwnerNode(*ex.origin()); + auto concatNode = TryGetOwnerNode(*ex.origin()); ReduceNode( NormalizeUnaryOperation, - *TryGetOwnerNode(*concatNode->input(0)->origin())); - concatNode = TryGetOwnerNode(*ex.origin()); + *TryGetOwnerNode(*concatNode->input(0)->origin())); + concatNode = TryGetOwnerNode(*ex.origin()); ReduceNode(NormalizeBinaryOperation, *concatNode); graph.PruneNodes(); @@ -1565,13 +1565,13 @@ ConcatFlattening() view(graph, stdout); // Act - const auto concatNode = TryGetOwnerNode(*ex.origin()); + const auto concatNode = TryGetOwnerNode(*ex.origin()); ReduceNode(FlattenBitConcatOperation, *concatNode); view(graph, stdout); // Assert - auto node = TryGetOwnerNode(*ex.origin()); + auto node = TryGetOwnerNode(*ex.origin()); assert(dynamic_cast(&node->GetOperation())); assert(node->ninputs() == 3); assert(node->input(0)->origin() == x); @@ -1637,7 +1637,7 @@ ConcatOfSlices() // Act ReduceNode(NormalizeBinaryOperation, concatNode); - ReduceNode(NormalizeUnaryOperation, *TryGetOwnerNode(*ex.origin())); + ReduceNode(NormalizeUnaryOperation, *TryGetOwnerNode(*ex.origin())); graph.PruneNodes(); view(graph, stdout); @@ -1666,10 +1666,10 @@ ConcatOfConstants() view(graph, stdout); // Act - ReduceNode(NormalizeBinaryOperation, *TryGetOwnerNode(*ex.origin())); + ReduceNode(NormalizeBinaryOperation, *TryGetOwnerNode(*ex.origin())); // Assert - auto node = TryGetOwnerNode(*ex.origin()); + auto node = TryGetOwnerNode(*ex.origin()); auto operation = dynamic_cast(node->GetOperation()); assert(operation.value() == bitvalue_repr("0011011111001000")); @@ -1709,8 +1709,8 @@ ConcatCne() view(graph, stdout); // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); graph.PruneNodes(); view(graph, stdout); @@ -1752,8 +1752,8 @@ SliceCne() view(graph, stdout); // Act - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*ex2.origin())); graph.PruneNodes(); view(graph, stdout); diff --git a/tests/jlm/rvsdg/test-cse.cpp b/tests/jlm/rvsdg/test-cse.cpp index eb2d54cb7..87b8594ca 100644 --- a/tests/jlm/rvsdg/test-cse.cpp +++ b/tests/jlm/rvsdg/test-cse.cpp @@ -41,29 +41,29 @@ test_main() auto & e4 = jlm::tests::GraphExport::Create(*o4, "o4"); // Act & Assert - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e1.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e2.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e3.origin())); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e4.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e1.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e2.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e3.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e4.origin())); assert(e1.origin() == e3.origin()); assert(e2.origin() == e4.origin()); auto o5 = jlm::tests::create_testop(&graph.GetRootRegion(), {}, { valueType })[0]; auto & e5 = jlm::tests::GraphExport::Create(*o5, "o5"); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e5.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e5.origin())); assert(e5.origin() == e1.origin()); auto o6 = jlm::tests::create_testop(&graph.GetRootRegion(), { i }, { valueType })[0]; auto & e6 = jlm::tests::GraphExport::Create(*o6, "o6"); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e6.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e6.origin())); assert(e6.origin() == e2.origin()); auto o7 = jlm::tests::create_testop(&graph.GetRootRegion(), {}, { valueType })[0]; auto & e7 = jlm::tests::GraphExport::Create(*o7, "o7"); assert(e7.origin() != e1.origin()); - ReduceNode(NormalizeCne, *TryGetOwnerNode(*e7.origin())); + ReduceNode(NormalizeCne, *TryGetOwnerNode(*e7.origin())); assert(e7.origin() == e1.origin()); return 0; diff --git a/tests/jlm/rvsdg/test-gamma.cpp b/tests/jlm/rvsdg/test-gamma.cpp index dde2897b6..12833a262 100644 --- a/tests/jlm/rvsdg/test-gamma.cpp +++ b/tests/jlm/rvsdg/test-gamma.cpp @@ -152,7 +152,7 @@ test_control_constant_reduction() view(&graph.GetRootRegion(), stdout); // Assert - auto match = TryGetOwnerNode(*ex1.origin()); + auto match = TryGetOwnerNode(*ex1.origin()); assert(match && is(match->GetOperation())); auto & match_op = to_match_op(match->GetOperation()); assert(match_op.default_alternative() == 0);