Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1153,10 +1153,10 @@ RhlsToFirrtlConverter::MlirGenHlsLocalMem(const jlm::rvsdg::SimpleNode * node)
{
auto lmem_op = dynamic_cast<const local_mem_op *>(&(node->GetOperation()));
JLM_ASSERT(lmem_op);
auto res_node = rvsdg::TryGetOwnerNode<rvsdg::Node>(**node->output(0)->begin());
auto res_node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(**node->output(0)->begin());
auto res_op = dynamic_cast<const local_mem_resp_op *>(&res_node->GetOperation());
JLM_ASSERT(res_op);
auto req_node = rvsdg::TryGetOwnerNode<rvsdg::Node>(**node->output(1)->begin());
auto req_node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(**node->output(1)->begin());
auto req_op = dynamic_cast<const local_mem_req_op *>(&req_node->GetOperation());
JLM_ASSERT(req_op);
// Create the module and its input/output ports - we use a non-standard way here
Expand Down Expand Up @@ -2788,8 +2788,8 @@ RhlsToFirrtlConverter::createInstances(
{
if (auto sn = dynamic_cast<jlm::rvsdg::SimpleNode *>(node))
{
if (dynamic_cast<const local_mem_req_op *>(&(node->GetOperation()))
|| dynamic_cast<const local_mem_resp_op *>(&(node->GetOperation())))
if (dynamic_cast<const local_mem_req_op *>(&(sn->GetOperation()))
|| dynamic_cast<const local_mem_resp_op *>(&(sn->GetOperation())))
{
// these are virtual - connections go to local_mem instead
continue;
Expand Down Expand Up @@ -3993,7 +3993,7 @@ RhlsToFirrtlConverter::GetFirrtlType(const jlm::rvsdg::Type * type)
}

std::string
RhlsToFirrtlConverter::GetModuleName(const rvsdg::Node * node)
RhlsToFirrtlConverter::GetModuleName(const rvsdg::SimpleNode * node)
{

std::string append = "";
Expand Down Expand Up @@ -4085,7 +4085,7 @@ RhlsToFirrtlConverter::IsIdentityMapping(const jlm::rvsdg::match_op & op)
void
RhlsToFirrtlConverter::WriteModuleToFile(
const circt::firrtl::FModuleOp fModuleOp,
const rvsdg::Node * node)
const rvsdg::SimpleNode * node)
{
if (!fModuleOp)
return;
Expand Down
4 changes: 2 additions & 2 deletions jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class RhlsToFirrtlConverter : public BaseHLS
MlirGen(const rvsdg::LambdaNode * lamdaNode);

void
WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::Node * node);
WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::SimpleNode * node);

void
WriteCircuitToFile(const circt::firrtl::CircuitOp circuit, std::string name);
Expand Down Expand Up @@ -283,7 +283,7 @@ class RhlsToFirrtlConverter : public BaseHLS
circt::firrtl::FIRRTLBaseType
GetFirrtlType(const jlm::rvsdg::Type * type);
std::string
GetModuleName(const rvsdg::Node * node);
GetModuleName(const rvsdg::SimpleNode * node);
bool
IsIdentityMapping(const jlm::rvsdg::match_op & op);

Expand Down
14 changes: 9 additions & 5 deletions jlm/hls/backend/rhls2firrtl/dot-hls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ DotHLS::loop_to_dot(hls::loop_node * ln)
dot << "{rank=same ";
for (auto node : rvsdg::TopDownTraverser(sr))
{
auto mx = dynamic_cast<const hls::mux_op *>(&node->GetOperation());
auto lc = dynamic_cast<const hls::loop_constant_buffer_op *>(&node->GetOperation());
auto simpleNode = dynamic_cast<const rvsdg::SimpleNode *>(node);
auto mx = dynamic_cast<const hls::mux_op *>(simpleNode ? &simpleNode->GetOperation() : nullptr);
auto lc = dynamic_cast<const hls::loop_constant_buffer_op *>(
simpleNode ? &simpleNode->GetOperation() : nullptr);
if ((mx && !mx->discarding && mx->loop) || lc)
{
dot << get_node_name(node) << " ";
Expand All @@ -260,7 +262,9 @@ DotHLS::loop_to_dot(hls::loop_node * ln)
dot << "{rank=same ";
for (auto node : rvsdg::TopDownTraverser(sr))
{
auto br = dynamic_cast<const hls::branch_op *>(&node->GetOperation());
auto simpleNode = dynamic_cast<const rvsdg::SimpleNode *>(node);
auto br =
dynamic_cast<const hls::branch_op *>(simpleNode ? &simpleNode->GetOperation() : nullptr);
if (br && br->loop)
{
dot << get_node_name(node) << " ";
Expand All @@ -272,9 +276,9 @@ DotHLS::loop_to_dot(hls::loop_node * ln)
// do edges outside in order not to pull other nodes into the cluster
for (auto node : rvsdg::TopDownTraverser(sr))
{
if (dynamic_cast<jlm::rvsdg::SimpleNode *>(node))
if (auto simpleNode = dynamic_cast<jlm::rvsdg::SimpleNode *>(node))
{
auto mx = dynamic_cast<const hls::mux_op *>(&node->GetOperation());
auto mx = dynamic_cast<const hls::mux_op *>(&simpleNode->GetOperation());
auto node_name = get_node_name(node);
for (size_t i = 0; i < node->ninputs(); ++i)
{
Expand Down
25 changes: 14 additions & 11 deletions jlm/hls/backend/rvsdg2rhls/add-prints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,23 @@ convert_prints(
convert_prints(structnode->subregion(n), printf, functionType);
}
}
else if (auto po = dynamic_cast<const print_op *>(&(node->GetOperation())))
else if (auto simpleNode = dynamic_cast<rvsdg::SimpleNode *>(node))
{
auto printf_local = route_to_region_rvsdg(printf, region); // TODO: prevent repetition?
auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id());
jlm::rvsdg::output * val = node->input(0)->origin();
if (*val->Type() != *jlm::rvsdg::bittype::Create(64))
if (auto po = dynamic_cast<const print_op *>(&(simpleNode->GetOperation())))
{
auto bt = std::dynamic_pointer_cast<const rvsdg::bittype>(val->Type());
JLM_ASSERT(bt);
val = &llvm::ZExtOperation::Create(*val, rvsdg::bittype::Create(64));
auto printf_local = route_to_region_rvsdg(printf, region); // TODO: prevent repetition?
auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id());
jlm::rvsdg::output * val = node->input(0)->origin();
if (*val->Type() != *jlm::rvsdg::bittype::Create(64))
{
auto bt = std::dynamic_pointer_cast<const jlm::rvsdg::bittype>(val->Type());
JLM_ASSERT(bt);
val = &llvm::ZExtOperation::Create(*val, rvsdg::bittype::Create(64));
}
llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val });
node->output(0)->divert_users(node->input(0)->origin());
jlm::rvsdg::remove(node);
}
llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val });
node->output(0)->divert_users(node->input(0)->origin());
jlm::rvsdg::remove(node);
}
}
}
Expand Down
164 changes: 83 additions & 81 deletions jlm/hls/backend/rvsdg2rhls/alloca-conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,99 +127,101 @@ alloca_conv(rvsdg::Region * region)
alloca_conv(structnode->subregion(n));
}
}
else if (auto po = dynamic_cast<const jlm::llvm::alloca_op *>(&(node->GetOperation())))
else if (auto simpleNode = dynamic_cast<rvsdg::SimpleNode *>(node))
{
// ensure that the size is one
JLM_ASSERT(node->ninputs() == 1);
auto constant_output = dynamic_cast<jlm::rvsdg::node_output *>(node->input(0)->origin());
JLM_ASSERT(constant_output);
auto constant_operation = dynamic_cast<const llvm::IntegerConstantOperation *>(
&constant_output->node()->GetOperation());
JLM_ASSERT(constant_operation);
JLM_ASSERT(constant_operation->Representation().to_uint() == 1);
// ensure that the alloca is an array type
auto at = std::dynamic_pointer_cast<const llvm::ArrayType>(po->ValueType());
JLM_ASSERT(at);
// detect loads and stores attached to alloca
TraceAllocaUses ta(node->output(0));
// create memory + response
auto mem_outs = local_mem_op::create(at, node->region());
auto resp_outs = local_mem_resp_op::create(*mem_outs[0], ta.load_nodes.size());
std::cout << "alloca converted " << at->debug_string() << std::endl;
// replace gep outputs (convert pointer to index calculation)
// replace loads and stores
std::vector<jlm::rvsdg::output *> load_addrs;
for (auto l : ta.load_nodes)
if (auto po = dynamic_cast<const jlm::llvm::alloca_op *>(&(simpleNode->GetOperation())))
{
auto index = gep_to_index(l->input(0)->origin());
auto response = route_response_rhls(l->region(), resp_outs.front());
resp_outs.erase(resp_outs.begin());
std::vector<jlm::rvsdg::output *> states;
for (size_t i = 1; i < l->ninputs(); ++i)
// ensure that the size is one
JLM_ASSERT(node->ninputs() == 1);
auto & constant_node =
rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*node->input(0)->origin());
auto constant_operation =
util::AssertedCast<const llvm::IntegerConstantOperation>(&constant_node.GetOperation());
JLM_ASSERT(constant_operation->Representation().to_uint() == 1);
// ensure that the alloca is an array type
auto at = std::dynamic_pointer_cast<const llvm::ArrayType>(po->ValueType());
JLM_ASSERT(at);
// detect loads and stores attached to alloca
TraceAllocaUses ta(node->output(0));
// create memory + response
auto mem_outs = local_mem_op::create(at, node->region());
auto resp_outs = local_mem_resp_op::create(*mem_outs[0], ta.load_nodes.size());
std::cout << "alloca converted " << at->debug_string() << std::endl;
// replace gep outputs (convert pointer to index calculation)
// replace loads and stores
std::vector<jlm::rvsdg::output *> load_addrs;
for (auto l : ta.load_nodes)
{
states.push_back(l->input(i)->origin());
auto index = gep_to_index(l->input(0)->origin());
auto response = route_response_rhls(l->region(), resp_outs.front());
resp_outs.erase(resp_outs.begin());
std::vector<jlm::rvsdg::output *> states;
for (size_t i = 1; i < l->ninputs(); ++i)
{
states.push_back(l->input(i)->origin());
}
auto load_outs = local_load_op::create(*index, states, *response);
auto nn = dynamic_cast<jlm::rvsdg::node_output *>(load_outs[0])->node();
for (size_t i = 0; i < l->noutputs(); ++i)
{
l->output(i)->divert_users(nn->output(i));
}
remove(l);
auto addr = route_request_rhls(node->region(), load_outs.back());
load_addrs.push_back(addr);
}
auto load_outs = local_load_op::create(*index, states, *response);
auto nn = dynamic_cast<jlm::rvsdg::node_output *>(load_outs[0])->node();
for (size_t i = 0; i < l->noutputs(); ++i)
std::vector<jlm::rvsdg::output *> store_operands;
for (auto s : ta.store_nodes)
{
l->output(i)->divert_users(nn->output(i));
auto index = gep_to_index(s->input(0)->origin());
std::vector<jlm::rvsdg::output *> states;
for (size_t i = 2; i < s->ninputs(); ++i)
{
states.push_back(s->input(i)->origin());
}
auto store_outs = local_store_op::create(*index, *s->input(1)->origin(), states);
auto nn = dynamic_cast<jlm::rvsdg::node_output *>(store_outs[0])->node();
for (size_t i = 0; i < s->noutputs(); ++i)
{
s->output(i)->divert_users(nn->output(i));
}
remove(s);
auto addr = route_request_rhls(node->region(), store_outs[store_outs.size() - 2]);
auto data = route_request_rhls(node->region(), store_outs.back());
store_operands.push_back(addr);
store_operands.push_back(data);
}
remove(l);
auto addr = route_request_rhls(node->region(), load_outs.back());
load_addrs.push_back(addr);
}
std::vector<jlm::rvsdg::output *> store_operands;
for (auto s : ta.store_nodes)
{
auto index = gep_to_index(s->input(0)->origin());
std::vector<jlm::rvsdg::output *> states;
for (size_t i = 2; i < s->ninputs(); ++i)
// TODO: ensure that loads/stores are either alloca or global, never both
// TODO: ensure that loads/stores have same width and alignment and geps can be merged -
// otherwise slice? create request
auto req_outs = local_mem_req_op::create(*mem_outs[1], load_addrs, store_operands);

// remove alloca from memstate merge
// TODO: handle general case of other nodes getting state edge without a merge
JLM_ASSERT(node->output(1)->nusers() == 1);
auto merge_in = *node->output(1)->begin();
auto merge_node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*merge_in);
if (dynamic_cast<const llvm::MemoryStateMergeOperation *>(&merge_node->GetOperation()))
{
states.push_back(s->input(i)->origin());
// merge after alloca -> remove merge
JLM_ASSERT(merge_node->ninputs() == 2);
auto other_index = merge_in->index() ? 0 : 1;
merge_node->output(0)->divert_users(merge_node->input(other_index)->origin());
jlm::rvsdg::remove(merge_node);
}
auto store_outs = local_store_op::create(*index, *s->input(1)->origin(), states);
auto nn = dynamic_cast<jlm::rvsdg::node_output *>(store_outs[0])->node();
for (size_t i = 0; i < s->noutputs(); ++i)
else
{
s->output(i)->divert_users(nn->output(i));
// TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing
// it to the region
JLM_ASSERT(false);
}
remove(s);
auto addr = route_request_rhls(node->region(), store_outs[store_outs.size() - 2]);
auto data = route_request_rhls(node->region(), store_outs.back());
store_operands.push_back(addr);
store_operands.push_back(data);
}
// TODO: ensure that loads/stores are either alloca or global, never both
// TODO: ensure that loads/stores have same width and alignment and geps can be merged -
// otherwise slice? create request
auto req_outs = local_mem_req_op::create(*mem_outs[1], load_addrs, store_operands);

// remove alloca from memstate merge
// TODO: handle general case of other nodes getting state edge without a merge
JLM_ASSERT(node->output(1)->nusers() == 1);
auto merge_in = *node->output(1)->begin();
auto merge_node = rvsdg::TryGetOwnerNode<rvsdg::Node>(*merge_in);
if (dynamic_cast<const llvm::MemoryStateMergeOperation *>(&merge_node->GetOperation()))
{
// merge after alloca -> remove merge
JLM_ASSERT(merge_node->ninputs() == 2);
auto other_index = merge_in->index() ? 0 : 1;
merge_node->output(0)->divert_users(merge_node->input(other_index)->origin());
jlm::rvsdg::remove(merge_node);
// TODO: run dne to
// remove loads/stores
// remove geps
// remove alloca pointer users
// remove alloca
}
else
{
// TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing it
// to the region
JLM_ASSERT(false);
}

// TODO: run dne to
// remove loads/stores
// remove geps
// remove alloca pointer users
// remove alloca
}
}
}
Expand Down
Loading
Loading