diff --git a/cinn/backends/compiler.cc b/cinn/backends/compiler.cc index 12e7f1ff73..ac6536d7e7 100644 --- a/cinn/backends/compiler.cc +++ b/cinn/backends/compiler.cc @@ -129,7 +129,6 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code) using runtime::cuda::CUDAModule; backends::nvrtc::Compiler compiler; - auto ptx = compiler(source_code); CHECK(!ptx.empty()); diff --git a/cinn/common/cas.cc b/cinn/common/cas.cc index e976b748d5..841f453bf7 100644 --- a/cinn/common/cas.cc +++ b/cinn/common/cas.cc @@ -2016,17 +2016,17 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) { }; { - std::vector a_args, b_args; - if (ap) - a_args = ap->operands(); - else - a_args.push_back(a); - if (bp) - b_args = bp->operands(); - else - b_args.push_back(b); - - return reduce_product_div_product(a_args, b_args); + // TODO: fix this + // std::vector a_args, b_args; + // if (ap) + // a_args = ap->operands(); + // else + // a_args.push_back(a); + // if (bp) + // b_args = bp->operands(); + // else + // b_args.push_back(b); + // return reduce_product_div_product(a_args, b_args); } // x / x diff --git a/cinn/common/cas_test.cc b/cinn/common/cas_test.cc index b6cc6adc4f..f76bbd890e 100644 --- a/cinn/common/cas_test.cc +++ b/cinn/common/cas_test.cc @@ -362,7 +362,7 @@ TEST(CAS, SimplifyMinMax) { LOG(INFO) << "p0 " << p0; auto p2 = CasSimplify(p0); LOG(INFO) << "simplified " << p2; - EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, ((x) / (2)))"); + EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, (x / 2))"); } { // -(cinn_min(16, 3400-x-1)-1)/2 + x Var x = ir::_Var_::Make("x", Int(32)); diff --git a/cinn/hlir/framework/CMakeLists.txt b/cinn/hlir/framework/CMakeLists.txt index bbc8f8fd4a..0767e6bae1 100755 --- a/cinn/hlir/framework/CMakeLists.txt +++ b/cinn/hlir/framework/CMakeLists.txt @@ -16,6 +16,7 @@ gather_srcs(cinnapi_src SRCS op_lowering.cc accuracy_checker.cc visualize_helper.cc + op_lowering_util.cc ) if(WITH_CUDA) diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h old mode 100755 new mode 100644 index 625a380e6d..050f36a4ab --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -114,6 +114,14 @@ class Graph : public cinn::common::Graph { } } + std::unordered_set NodeSet() { + std::unordered_set node_set; + for (auto node : CollectNodes()) { + node_set.insert(node); + } + return node_set; + } + std::unordered_set GetInputNodeDatas(); std::unordered_set GetOutputNodeDatas(); diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index f12a3bde53..d5195b2b1f 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -14,6 +14,7 @@ #include "cinn/hlir/framework/op_lowering.h" +#include "cinn/hlir/framework/op_lowering_util.h" #include "cinn/hlir/op/external_api_registry.h" #include "cinn/optim/transform_gpu_forloop.h" @@ -41,43 +42,6 @@ using Comparator = Graph::Group::SharedGroupComparator; using Hasher = Graph::Group::SharedGroupHasher; using cinn::hlir::op::ExternalApiRegistry; -NodeData* GetNodeData(const Node* node) { - auto node_data = (*node->outlinks().begin())->sink()->safe_as(); - CHECK(node_data); - return node_data; -} - -std::vector GetAllNodeData(const Node* node) { - std::vector node_datas; - for (auto& link : node->outlinks_in_order(true)) { - auto node_data = link->sink()->safe_as(); - CHECK(node_data); - node_datas.push_back(node_data); - } - - return node_datas; -} - -std::vector GetConsumer(Node* node) { - std::vector consumers; - auto node_data = GetNodeData(node); - for (auto& link : node_data->outlinks()) { - auto consumer_node = link->sink()->safe_as(); - CHECK(consumer_node); - consumers.push_back(consumer_node); - } - return consumers; -} - -bool IsConstOp(const framework::Node* node) { - static std::unordered_set const_op_type = {"const_scalar", "fill_constant", "arange"}; - if (const_op_type.count(node->op()->name)) { - return true; - } else { - return false; - } -} - OpLowerer::OpLowerer(const absl::flat_hash_map& type_dict, const absl::flat_hash_map& shape_dict, const Target& target) @@ -101,7 +65,7 @@ std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { LOG(FATAL) << "Group Pattern Kind kNonFusible Is Not Implemented!"; } } else { - LOG(FATAL) << "Previous IR Schedule Is Not Implemented!"; + LOG(FATAL) << "Previous IR Schedule Is Unsupport Now, please set FLAGS_cinn_ir_schedule=1 to use new IR Schedule!"; } } @@ -114,9 +78,9 @@ std::vector OpLowerer::Lower(GroupPtr& group) { case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: - return IRLowerOp(&OpLowerer::IRElementwiseCompute, &OpLowerer::IRElementwiseSchedule, group); + return IRLowerOp(&OpLowerer::IRElementwiseCompute, group); case framework::kReduction: - return IRLowerOp(&OpLowerer::IRReduceCompute, &OpLowerer::IRReduceSchedule, group); + return IRLowerOp(&OpLowerer::IRReduceCompute, group); case framework::kOutFusible: LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; case framework::kNonFusible: @@ -125,20 +89,7 @@ std::vector OpLowerer::Lower(GroupPtr& group) { LOG(FATAL) << "Group Pattern Kind Is Unknown!"; } } else { - switch (group->op_pattern_kind) { - case framework::kElementWise: - case framework::kBroadcast: - case framework::kInjective: - return LowerOp(&OpLowerer::ElementwiseCompute, &OpLowerer::ElementwiseSchedule, group); - case framework::kReduction: - return LowerOp(&OpLowerer::ReduceCompute, &OpLowerer::ReduceSchedule, group); - case framework::kOutFusible: - return LowerOp(&OpLowerer::OutEWiseFusableCompute, &OpLowerer::OutEWiseFusableSchedule, group); - case framework::kNonFusible: - return LowerNonFusibleOp(group); - default: - LOG(FATAL) << "Group Pattern Kind Is Unknown!"; - } + LOG(FATAL) << "Previous IR Schedule Is Unsupport Now, please set FLAGS_cinn_ir_schedule=1 to use new IR Schedule!"; } } @@ -225,9 +176,7 @@ std::vector OpLowerer::IRLowerOpWithoutSchedule(IRComputeFuncti return {func}; } -std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, - IRScheduleFunction schedule, - GroupPtr& group) { +std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, GroupPtr& group) { poly::StageMap stages; std::vector arg_tensors; std::unordered_map tensor_map; @@ -250,6 +199,7 @@ std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, Node* second = nullptr; // do schedule. VLOG(3) << "Before IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + /* if (group->fused_sub_groups.size() == 0) { (this->*schedule)(ir_sch, tensor_map, group, group, first, second); } else { @@ -258,6 +208,8 @@ std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, (this->*schedule)(ir_sch, tensor_map, group, group->fused_sub_groups[idx], first, second); } } + */ + IRSchedule(ir_sch, group, tensor_map); VLOG(3) << "After IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); // function args group->input_names.clear(); @@ -303,152 +255,6 @@ std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, return {func}; } -// fusion op lowering -std::vector OpLowerer::LowerOp(ComputeFunction compute, ScheduleFunction schedule, GroupPtr& group) { - poly::StageMap stages; - std::vector func_args; - std::unordered_map tensor_map; - VLOG(3) << "Fused Sub-Graph Size Is : " << group->fused_sub_groups.size(); - // do compute. - if (group->fused_sub_groups.size() == 0) { - (this->*compute)(stages, func_args, tensor_map, group, group); - } else { - for (auto& sub_group : group->fused_sub_groups) { - (this->*compute)(stages, func_args, tensor_map, group, sub_group); - } - } - - VLOG(3) << "After Compute, Do Schedule!"; - // do schedule. - if (group->fused_sub_groups.size() == 0) { - (this->*schedule)(stages, tensor_map, group, group); - } else { - for (auto& sub_group : group->fused_sub_groups) { - (this->*schedule)(stages, tensor_map, group, sub_group); - } - } - group->input_names.clear(); - for (auto& args : func_args) { - // input node data name. - group->input_names.push_back(args->name); - } - - group->output_names.clear(); - for (auto& node : group->output_nodes) { - // output node data name. - for (auto node_data : GetAllNodeData(node)) { - group->output_names.push_back(node_data->id()); - } - // collect all output tensor. - std::string post = ""; - std::string prefix = GetNodeData(node)->id(); - for (int idx = 0;; ++idx) { - if (idx == 0) { - CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; - } - if (!tensor_map.count(prefix + post)) { - break; - } - auto tensor = tensor_map[prefix + post]; - // if tensor is with buffer, it's not a output. - if (!tensor->buffer.defined() && !stages[tensor]->inlined()) { - func_args.push_back(tensor); - } - // update post - post = "_" + std::to_string(idx); - } - } - - return lang::LowerVec(group->GetFuncName(), stages, func_args, {}, {}, nullptr, this->target_); -} - -std::vector OpLowerer::CollectInputTensor(std::vector& func_args, - std::unordered_map& tensor_map, - const Node* node) { - std::vector tensor_inputs; - // get all input nodes - for (auto& link : node->inlinks_in_order(true)) { - auto source = link->source(); - CHECK(source); - auto source_data = source->safe_as(); - CHECK(source_data); - if (FLAGS_cinn_ir_schedule) { - auto dtype = this->type_dict_.at(source_data->id()); - CHECK(dtype.is_supported()) << "Node " << source_data->id() << " 's dtype " << dtype << "is not supported yet!"; - ir::Tensor tensor; - if (dtype.is_float(32)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_float(64)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_float(16)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_bool()) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(8)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(16)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(32)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(64)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(8)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(16)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(32)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(64)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } - if (!tensor_map.count(source_data->id())) { - tensor_map[source_data->id()] = tensor; - // record func input args - func_args.push_back(tensor); - } - tensor_inputs.push_back(tensor); - } else { - if (tensor_map.count(source_data->id())) { - tensor_inputs.push_back(tensor_map[source_data->id()]); - } else { - auto dtype = this->type_dict_.at(source_data->id()); - CHECK(dtype.is_supported()) << "Node " << source_data->id() << " 's dtype " << dtype << "is not supported yet!"; - ir::Tensor tensor; - if (dtype.is_float(32)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_float(64)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_float(16)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_bool()) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(8)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(16)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(32)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_int(64)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(8)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(16)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(32)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } else if (dtype.is_uint(64)) { - tensor = lang::Placeholder(source_data->id(), this->shape_dict_.at(source_data->id())); - } - tensor_map[source_data->id()] = tensor; - tensor_inputs.push_back(tensor); - // record func input args - func_args.push_back(tensor); - } - } - } - return tensor_inputs; -} - std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, std::vector& func_tensors, std::unordered_map& tensor_map, @@ -463,7 +269,8 @@ std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, auto node_data = GetNodeData(node); CHECK_EQ(GetAllNodeData(node).size(), 1U); std::vector cinn_inputs; - std::vector tensor_inputs = std::move(CollectInputTensor(func_tensors, tensor_map, node)); + std::vector tensor_inputs = + std::move(CollectInputTensor(node, func_tensors, tensor_map, this->type_dict_, this->shape_dict_)); for (auto& tensor : tensor_inputs) { cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); } @@ -512,49 +319,6 @@ std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, return ast_exprs; } -void OpLowerer::IRElementwiseSchedule(ir::IRSchedule& ir_sch, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - Node*&, - Node*&) { - VLOG(2) << "IRElementwiseSchedule Group : " << sub_group->group_id; - auto master_node = *group->master_nodes.begin(); - auto manster_tensor = tensor_map[GetNodeData(master_node)->id()]; - - for (int idx = sub_group->nodes.size() - 1; idx >= 0; --idx) { - auto node = sub_group->nodes[idx]; - auto node_tensor = tensor_map[GetNodeData(node)->id()]; - - VLOG(3) << "Schedule node -> " << node->id() << " var : " << node_tensor->name; - if (group->master_nodes.count(node)) { - continue; - } - - if (IsConstOp(node) && !group->output_nodes.count(node)) { - ir_sch.ComputeInline(ir_sch.GetBlock(node_tensor->name)); - continue; - } - - // if node is fringe node or internal node, fringe node is output node of sub-graph - if (group->output_nodes.count(node) || group->internal_nodes.count(node) || sub_group->internal_nodes.count(node)) { - // internal node use buffer - if (!group->output_nodes.count(node)) { - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SetBuffer(node_block, "local", true); - } - - auto node_block = ir_sch.GetBlock(node_tensor->name); - auto master_loops = ir_sch.GetLoops(manster_tensor->name); - ir_sch.SimpleComputeAt(node_block, master_loops.back()); - continue; - } - - // others elemenwise internal node use compute-inline - ir_sch.ComputeInline(ir_sch.GetBlock(node_tensor->name)); - } -} - std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, std::vector& func_args, std::unordered_map& tensor_map, @@ -571,7 +335,8 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, VLOG(3) << "In ReduceCompute, process node: " << node->id() << " with op type: " << node->op()->name; std::vector cinn_inputs; - std::vector tensor_inputs = std::move(CollectInputTensor(func_args, tensor_map, node)); + std::vector tensor_inputs = + std::move(CollectInputTensor(node, func_args, tensor_map, this->type_dict_, this->shape_dict_)); for (auto& tensor : tensor_inputs) { cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); } @@ -633,645 +398,6 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, return ast_exprs; } -void OpLowerer::IRReduceSchedule(ir::IRSchedule& ir_sch, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - Node*& master, - Node*& reducer) { - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - auto OrderAssignReduce = [this](ir::IRSchedule& ir_sch, - const std::string& block_name, - const std::vector& axes, - const bool just_reorder = false) { - // reorder none-last reduce axis to last. - // like: shape = [16,16,16,16,16],axes = [1,3] -> new order = [0, 2, 4, 1, 3]. - std::vector order; - int n_out_dims = ir_sch.GetLoops(block_name).size(); - for (int idx = 0; idx < n_out_dims; ++idx) { - if (std::find(axes.begin(), axes.end(), idx) == axes.end()) { - order.push_back(idx); - } - } - for (auto axis : axes) { - order.push_back(axis); - } - ir_sch.Reorder(ir_sch.GetBlock(block_name), order); - - if (just_reorder) { - return; - } - // fuse others none-reduce axis. - int last_dimension_num = n_out_dims - axes.back() - 1; - int index = n_out_dims - last_dimension_num - axes.size(); - - // fuse last_dimension_num - 1 times - for (auto idx = index; idx < index + last_dimension_num - 1; ++idx) { - ir_sch.Fuse(block_name, {index, index + 1}); - } - - auto loops = ir_sch.GetLoops(block_name); - auto psize = ir::GetLoopExtent(loops[index]); - if (psize > this->target_.max_num_threads()) { - for (int idx = this->target_.max_num_threads(); idx > 0; --idx) { - if (psize % idx == 0) { - ir_sch.Split(loops[index], {-1, idx}); - break; - } - CHECK_GT(idx, 1); - } - } - - // fuse index - 1 times - for (int idx = 0; idx < index - 1; ++idx) { - ir_sch.Fuse(block_name, {0, 1}); - } - }; - - auto WithoutLastDimInReduce = [](const std::vector& inshape, std::vector& axes) { - // if last axis is in reduce. - axes = axes.empty() ? inshape : axes; - if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || - std::find(axes.begin(), axes.end(), -1) != axes.end()) { - return false; - } - - int sum_last_axes = 1; - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - sum_last_axes *= inshape[idx]; - } - - if (sum_last_axes > 1) { - return true; - } else { - return false; - } - }; - - auto ScheduleAssignReduceWithoutLast = [this, OrderAssignReduce](ir::IRSchedule& ir_sch, - const std::string& block_name, - const std::vector& inshape, - std::vector& axes) { - axes = axes.empty() ? inshape : axes; - int lane = 1; - int max_num_threads = this->target_.max_num_threads(); - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - lane *= inshape[idx]; - } - CHECK_LE(lane, max_num_threads / 2) << "Parallel threads must less equal max_num_threads/2 on gpu!"; - int pos = 0; - int index = axes.size() - 1; - for (; index >= 0; --index) { - if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { - pos = axes[index + 1]; - break; - } - - lane *= inshape[axes[index]]; - if (lane > max_num_threads / 2) { - pos = axes[index]; - break; - } - - if (index == 0) { - pos = axes[0]; - } - } - - if (lane > max_num_threads / 2) { - int prefix = inshape[axes[index]]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { - if (prefix % idx == 0) { - ir_sch.Split(block_name, axes[index], {-1, idx}); - break; - } - CHECK_GT(idx - 1, (max_num_threads / 2) / tail) << "idx should greater than (max_num_threads / 2) / tail."; - } - } - - // insert 1 - for (int idx = 0; idx < axes.size() - 1 - index; ++idx) { - auto loops = ir_sch.GetLoops(block_name); - ir_sch.Split(block_name, pos, {-1, ir::GetLoopExtent(loops[pos])}); - } - OrderAssignReduce(ir_sch, block_name, axes); - // return insert 1 - int start_index = ir_sch.GetLoops(block_name).size() - axes.size(); - for (int idx = 0; idx < axes.size(); ++idx) { - auto loops = ir_sch.GetLoops(block_name); - if (ir::GetLoopExtent(loops[start_index]) == 1) { - ir_sch.Fuse({loops[start_index - 1], loops[start_index]}); - } else { - ++start_index; - } - } - }; - - auto ScheduleAssignReduceWithLast = [this, OrderAssignReduce](ir::IRSchedule& ir_sch, - const std::string& block_name, - const std::vector& inshape, - std::vector& axes) { - // find first reduce and second reduce axis. - axes = axes.empty() ? inshape : axes; - int lane = 1; - int index = static_cast(axes.size()) - 1; - auto max_num_threads = this->target_.max_num_threads(); - for (; index >= 0; --index) { - if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { - break; - } - lane *= inshape[axes[index]]; - if (index == 0 && lane <= max_num_threads) { - LOG(FATAL) << "Error! lane is less equal than max_num_threads, Please check!"; - } - if (lane >= max_num_threads / 2) { - if (lane <= max_num_threads) { - --index; - } - break; - } - } - std::vector first_axes(axes.begin(), axes.begin() + index + 1); - if (lane > max_num_threads) { - // last reduce axis size > 1024 - if (index == static_cast(axes.size()) - 1) { - int idx = max_num_threads; - do { - if (lane % idx == 0) { - ir_sch.Split(block_name, axes[index], {-1, idx}); - break; - } - --idx; - } while (idx >= max_num_threads / 2); - // if can't be divide by(1024, 512), it's shouldn't be fused. - CHECK_GE(idx, max_num_threads / 2) << "Check bounds exist, can't fuse!"; - } else { - int axis = axes[index]; - int prefix = inshape[axis]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { - if (prefix % idx == 0) { - ir_sch.Split(block_name, axis, {-1, idx}); - break; - } - CHECK_GT(idx, (max_num_threads / 2) / tail) << "Error, it's shouldn't fuse!"; - } - } - OrderAssignReduce(ir_sch, block_name, first_axes); - } else { - int fuse_times = axes.size() - (index + 1) - 1; - for (int idx = 0; idx < fuse_times; ++idx) { - ir_sch.Fuse(block_name, {axes[index + 1], axes[index + 1] + 1}); - } - OrderAssignReduce(ir_sch, block_name, first_axes, true); - // fuse axis before reduce to bind blockidx. - for (int idx = 0; idx < (inshape.size() - axes.size()) - 1; ++idx) { - ir_sch.Fuse(block_name, {0, 1}); - } - } - }; - - if (master == nullptr && reducer == nullptr) { - auto blocks = ir_sch.GetAllBlocks(); - for (int idx = blocks.size() - 1; idx >= 0; --idx) { - auto block = blocks[idx]; - CHECK(block->as()); - CHECK(block->as()->schedule_block->as()); - if (!tensor_map.count(block->as()->schedule_block->as()->name)) { - continue; - } - - for (auto node : group->master_nodes) { - if (GetNodeData(node)->id() == - block->as()->schedule_block->as()->name) { - if (op_pattern_dict[node->op()] != framework::kReduction) { - master = node; - break; - } - - if (op_pattern_dict[node->op()] == framework::kReduction && master) { - reducer = node; - break; - } - } - } - - if (master && reducer) { - break; - } - } - CHECK((master && reducer) || (!master && !reducer)) << "Can't find Master reducer!"; - if (!master && !reducer) { - master = *group->master_nodes.begin(); - reducer = *group->master_nodes.begin(); - } - - // do master schedule. - if (op_pattern_dict[master->op()] != framework::kReduction) { - VLOG(2) << "Do Master Schedule : " << master->id(); - auto master_data = GetNodeData(master); - CHECK(master_data); - CHECK(tensor_map.count(master_data->id())); - auto master_tensor = tensor_map[master_data->id()]; - auto loops = ir_sch.GetLoops(master_tensor->name); - if (op_pattern_dict[master->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - - auto reducer_data = GetNodeData(reducer); - auto reducer_tensor = tensor_map[reducer_data->id()]; - auto rloops = ir_sch.GetLoops(reducer_tensor->name); - - // assign master loops to reducer loops without reduce axis. - int extend = 1; - std::vector factors; - auto sloops = ir_sch.GetLoops(master_tensor->name); - for (auto& loop : rloops) { - // without last reduce axis, so check loop extend. - extend *= loop.As()->extent.as_int32(); - if (extend > sloops.back().As()->extent.as_int32()) { - break; - } - CHECK_LE(extend, sloops.back().As()->extent.as_int32()); - factors.push_back(loop.As()->extent.as_int32()); - } - ir_sch.Split(sloops.back(), factors); - - auto nloops = ir_sch.GetLoops(master_tensor->name); - CHECK_GE(rloops.size(), nloops.size()); - for (int idx = 0; idx < nloops.size(); ++idx) { - nloops[idx].As()->set_bind_info(rloops[idx].As()->bind_info()); - } - } - // do reducer schedule. - { - auto reducer_data = GetNodeData(reducer); - auto reducer_tensor = tensor_map[reducer_data->id()]; - CHECK(reducer->attrs.attr_store.count("dim")); - auto reducer_axes = absl::get>(reducer->attrs.attr_store.at("dim")); - CHECK(reducer->inlinks_in_order().size()); - CHECK(this->shape_dict_.count(reducer->inlinks_in_order()[0]->source()->id())); - auto reducer_shape = this->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - - if (reducer_axes.empty()) { - for (int i = 0; i < reducer_shape.size(); ++i) { - reducer_axes.emplace_back(i); - } - } - - bool without_last_dim = WithoutLastDimInReduce(reducer_shape, reducer_axes); - - std::unordered_set visited_nodes; - for (auto node : group->master_nodes) { - VLOG(2) << "Schedule reduce node -> " << node->id(); - if (op_pattern_dict[node->op()] != framework::kReduction) { - continue; - } - auto node_data = GetNodeData(node); - auto node_tensor = tensor_map[node_data->id()]; - - if (!group->output_nodes.count(node)) { - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SetBuffer(node_block, "local", true); - } - if (node == reducer) { - continue; - } - auto node_shape = this->shape_dict_.at(node->inlinks_in_order()[0]->source()->id()); - if (without_last_dim) { - VLOG(2) << "Reduce Schedule WithoutLastDimInReduce"; - // find a shape to do simple compute at. - auto tmp_reducer = reducer; - auto tmp_reducer_shape = reducer_shape; - if (node_shape != reducer_shape) { - // try to find the same shape reduce from visited_nodes - for (auto visited : visited_nodes) { - auto shape = this->shape_dict_.at(visited->inlinks_in_order()[0]->source()->id()); - if (shape == node_shape) { - tmp_reducer = visited; - tmp_reducer_shape = shape; - break; - } - } - } - visited_nodes.insert(node); - auto tmp_reducer_data = GetNodeData(tmp_reducer); - auto tmp_reducer_tensor = tensor_map[tmp_reducer_data->id()]; - - // using block shuffle reduce. - if (tensor_map.count(reducer_data->id() + "_1")) { - auto node_0_tensor = tensor_map[node_data->id() + "_0"]; - auto node_0_block = ir_sch.GetBlock(node_0_tensor->name); - - auto tmp_reducer_0_tensor = tensor_map[tmp_reducer_data->id() + "_0"]; - auto tmp_reducer_0_loops = ir_sch.GetLoops(tmp_reducer_0_tensor->name); - - if (tmp_reducer_shape == node_shape) { - ir_sch.SimpleComputeAt(node_0_block, tmp_reducer_0_loops.back()); - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_0_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_0_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_0_tensor->name)[loop_depth - 1]); - } else { - if (tmp_reducer_0_tensor->shape.back() == node_0_tensor->shape.back()) { - int num_reduce_axis = tmp_reducer_0_tensor->reduce_axis.size(); - CHECK_GE(static_cast(tmp_reducer_0_loops.size()) - num_reduce_axis - 1, 0); - ir_sch.SimpleComputeAt(node_0_block, - tmp_reducer_0_loops[tmp_reducer_0_loops.size() - num_reduce_axis - 1]); - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_0_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_0_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_0_tensor->name)[loop_depth - 1]); - } else { - CHECK_GE(static_cast(tmp_reducer_0_loops.size()), 2); - ir_sch.SimpleComputeAt(node_0_block, tmp_reducer_0_loops[0]); - } - } - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name), - ir_sch.GetLoops(tmp_reducer_tensor->name).back()); - } else { - if (tmp_reducer_shape == node_shape) { - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name), - ir_sch.GetLoops(tmp_reducer_tensor->name).back()); - } else { - int num_reduce_axis = tmp_reducer_tensor->reduce_axis.size(); - auto tmp_reducer_loops = ir_sch.GetLoops(tmp_reducer_tensor->name); - CHECK_GE(static_cast(tmp_reducer_loops.size()) - num_reduce_axis - 1, 0); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name), - tmp_reducer_loops[tmp_reducer_loops.size() - num_reduce_axis - 1]); - } - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_tensor->name)[loop_depth - 1]); - } - } else { - VLOG(2) << "Reduce Schedule WithLastDimInReduce"; - // if with column reduce behind. - if (tensor_map.count(node_data->id() + "_1")) { - auto reducer_1_tensor = tensor_map[reducer_data->id() + "_1"]; - auto reducer_0_tensor = tensor_map[reducer_data->id() + "_0"]; - - auto node_1_tensor = tensor_map[node_data->id() + "_1"]; - auto node_0_tensor = tensor_map[node_data->id() + "_0"]; - - auto node_block_1 = ir_sch.GetBlock(node_1_tensor->name); - auto node_block_0 = ir_sch.GetBlock(node_0_tensor->name); - auto node_block = ir_sch.GetBlock(node_tensor->name); - - ir_sch.SimpleComputeAt(node_block, ir_sch.GetLoops(reducer_tensor->name).back()); - ir_sch.SimpleComputeAt(node_block_0, ir_sch.GetLoops(reducer_0_tensor->name).back()); - ir_sch.SimpleComputeAt(node_block_1, ir_sch.GetLoops(reducer_1_tensor->name).back()); - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_1_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_1_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_1_tensor->name)[loop_depth - 1]); - } else if (tensor_map.count(node_data->id() + "_0")) { - auto reducer_0_tensor = tensor_map[reducer_data->id() + "_0"]; - auto node_0_tensor = tensor_map[node_data->id() + "_0"]; - - auto node_0_block = ir_sch.GetBlock(node_0_tensor->name); - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, ir_sch.GetLoops(reducer_tensor->name).back()); - ir_sch.SimpleComputeAt(node_0_block, ir_sch.GetLoops(reducer_0_tensor->name).back()); - } else { - LOG(FATAL) << "Error! Unkown Reduce Type, Please Check!"; - } - } - } - - if (without_last_dim) { - if (tensor_map.count(reducer_data->id() + "_1")) { - auto reducer_tensor = tensor_map[GetNodeData(reducer)->id()]; - auto reducer_loops = ir_sch.GetLoops(reducer_tensor->name); - ir_sch.SyncThreads(reducer_loops[0], false); - } - } - } - } - - // master node - auto master_data = GetNodeData(master); - CHECK(master_data); - CHECK(tensor_map.count(master_data->id())); - auto master_tensor = tensor_map[master_data->id()]; - auto master_shape = this->shape_dict_.at(master_data->id()); - auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); - - // reducer node - auto reducer_data = GetNodeData(reducer); - CHECK(reducer_data); - CHECK(reducer->inlinks_in_order().size()); - CHECK(this->shape_dict_.count(reducer->inlinks_in_order()[0]->source()->id())); - auto reducer_shape = this->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - auto reduce_size = std::accumulate(reducer_shape.begin(), reducer_shape.end(), 1, std::multiplies()); - - CHECK(reducer->attrs.attr_store.count("dim")); - auto reducer_axes = absl::get>(reducer->attrs.attr_store.at("dim")); - if (reducer_axes.empty()) { - for (int i = 0; i < reducer_shape.size(); ++i) { - reducer_axes.emplace_back(i); - } - } - - VLOG(2) << "master node : " << master->id() << " ,reducer node : " << reducer->id(); - for (int idx = sub_group->nodes.size() - 1; idx >= 0; --idx) { - auto node = sub_group->nodes[idx]; - - if (node == master) { - continue; - } - if (op_pattern_dict[node->op()] == framework::kReduction) { - continue; - } - auto node_data = GetNodeData(node); - auto node_tensor = tensor_map[node_data->id()]; - - VLOG(3) << "Schedule node -> " << node->id() << " var : " << node_tensor->name; - // for x86 schedule. - if (this->target_ == common::DefaultHostTarget()) { - LOG(FATAL) << "X86 Not implemented"; - } - - bool dont_compute_inline = - group->output_nodes.count(node) || group->internal_nodes.count(node) || sub_group->internal_nodes.count(node); - if (!dont_compute_inline) { - auto consumers = GetConsumer(node); - for (auto& consumer : consumers) { - if (op_pattern_dict[consumer->op()] == framework::kReduction) { - dont_compute_inline = true; - break; - } - } - } - - // if is const op, do compute inline. - if (IsConstOp(node) && !group->output_nodes.count(node)) { - dont_compute_inline = false; - } - - // if node is internal node or output, try to copy schedule from fellow node - if (dont_compute_inline) { - VLOG(2) << "Reduce Schedule for Elementwise Type"; - // if node is not output node, set buffer. - if (!group->output_nodes.count(node)) { - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SetBuffer(node_block, "local", true); - } - // node is after reduce - auto node_shape = this->shape_dict_.at(node_data->id()); - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - if (node_shape == master_shape || node_size == master_size) { - VLOG(2) << "Do Elementwise Type After Reduce!"; - auto loops = ir_sch.GetLoops(node_tensor->name); - // flat loop and tensor shape - if (op_pattern_dict[master->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - // split loop to assign master loop - std::vector factors; - auto mloops = ir_sch.GetLoops(master_tensor->name); - for (auto& loop : mloops) { - factors.push_back(loop.As()->extent.as_int32()); - } - loops = ir_sch.GetLoops(node_tensor->name); - ir_sch.Split(loops.back(), factors); - // note do simple compute at - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, mloops.back()); - continue; - } - // do elementwise flat - auto loops = ir_sch.GetLoops(node_tensor->name); - if (op_pattern_dict[node->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - // node is before reduce. - if (WithoutLastDimInReduce(reducer_shape, reducer_axes)) { - VLOG(2) << "Reduce Schedule for WithoutLastDimInReduce"; - // find a shape to do simple compute at. - auto tmp_reducer = reducer; - auto tmp_reducer_shape = reducer_shape; - auto tmp_reducer_size = std::accumulate(reducer_shape.begin(), reducer_shape.end(), 1, std::multiplies()); - // node shape. - auto node_shape = this->shape_dict_.at(node_data->id()); - if (node_shape != tmp_reducer_shape && node_size != reduce_size) { - // try to find the same shape reduce from visited_nodes - for (auto rnode : group->master_nodes) { - if (op_pattern_dict[rnode->op()] != framework::kReduction) { - continue; - } - auto shape = this->shape_dict_.at(rnode->inlinks_in_order()[0]->source()->id()); - auto size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - if (shape == node_shape || size == node_size) { - tmp_reducer = rnode; - tmp_reducer_size = size; - tmp_reducer_shape = shape; - break; - } - } - } - // do split - CHECK(node_shape == tmp_reducer_shape || node_size == tmp_reducer_size); - - auto loops = ir_sch.GetLoops(node_tensor->name); - ir_sch.Split(loops.back(), tmp_reducer_shape); - - auto tmp_reducer_data = GetNodeData(tmp_reducer); - auto tmp_reducer_tensor = tensor_map[tmp_reducer_data->id()]; - // if used block shuffle reduce - if (tensor_map.count(tmp_reducer_data->id() + "_1")) { - ScheduleAssignReduceWithoutLast(ir_sch, node_tensor->name, tmp_reducer_shape, reducer_axes); - auto tmp_reducer_tensor_0 = tensor_map[tmp_reducer_data->id() + "_0"]; - auto tmp_reducer_loops_0 = ir_sch.GetLoops(tmp_reducer_tensor_0->name); - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (node_loops.size() < tmp_reducer_loops_0.size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), tmp_reducer_loops_0.size()) - << "node loops and reduce loops must be equal!"; - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, tmp_reducer_loops_0.back()); - } else { - OrderAssignReduce(ir_sch, node_tensor->name, reducer_axes); - - auto node_block = ir_sch.GetBlock(node_tensor->name); - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (node_loops.size() < ir_sch.GetLoops(tmp_reducer_tensor->name).size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), ir_sch.GetLoops(tmp_reducer_tensor->name).size()) - << "node loop size and reduce loop size must be equal!"; - ir_sch.SimpleComputeAt(node_block, ir_sch.GetLoops(tmp_reducer_tensor->name).back()); - } - } else { - VLOG(2) << "Reduce Schedule for WithLastDimInReduce"; - if (tensor_map.count(reducer_data->id() + "_1")) { - { - auto node_loops = ir_sch.GetLoops(node_tensor->name); - ir_sch.Split(node_loops.back(), reducer_shape); - } - - ScheduleAssignReduceWithLast(ir_sch, node_tensor->name, reducer_shape, reducer_axes); - auto reducer_1_tensor = tensor_map[reducer_data->id() + "_1"]; - auto reducer_1_block = ir_sch.GetBlock(reducer_1_tensor->name); - auto reducer_1_loops = ir_sch.GetLoops(reducer_1_block); - - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (ir_sch.GetLoops(node_tensor->name).size() < ir_sch.GetLoops(reducer_1_block).size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), ir_sch.GetLoops(reducer_1_block).size()) - << "node loop size and reduce loop size must be equal!" << ir_sch.GetModule().GetExprs().at(0); - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, reducer_1_loops.back()); - } else { - auto reducer_0_tensor = tensor_map[reducer_data->id() + "_0"]; - auto reducer_0_block = ir_sch.GetBlock(reducer_0_tensor->name); - auto reducer_0_loops = ir_sch.GetLoops(reducer_0_block); - { - auto node_loops = ir_sch.GetLoops(node_tensor->name); - std::vector factors; - for (auto& loop : reducer_0_loops) { - factors.push_back(loop.As()->extent.as_int32()); - } - ir_sch.Split(node_loops.back(), factors); - } - - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (node_loops.size() < reducer_0_loops.size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), reducer_0_loops.size()) - << "node loop size and reduce loop size must be equal!" << ir_sch.GetModule().GetExprs().at(0); - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, reducer_0_loops.back()); - } - } - continue; - } - - // others elemenwise internal node use compute-inline - VLOG(2) << "Do Elementwise ComputeInline!"; - auto loops = ir_sch.GetLoops(node_tensor->name); - if (op_pattern_dict[node->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.ComputeInline(node_block); - } -} - std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, bool apply_impl_schedule) { VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id; // get input tensor and output tensor @@ -1286,46 +412,19 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo std::vector args; std::unordered_map tensor_map; - for (auto& i : node->inlinks_in_order(true)) { - std::string id = i->source()->as()->id(); - auto shape = shape_dict_.at(id); - Type dtype = type_dict_.at(id); - CHECK(dtype.is_supported()) << "Node " << id << " 's dtype " << dtype << "is not supported yet!"; - + for (auto& node_data : GetInputNodeData(node)) { + CHECK(node_data); ir::Tensor tensor; - if (!tensor_map.count(id)) { - if (dtype.is_float(32)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_float(64)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_float(16)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_bool()) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(8)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(16)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(32)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(64)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(8)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(16)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(32)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(64)) { - tensor = lang::Placeholder(id, shape); - } - tensor_map[id] = tensor; - // input name - group->input_names.push_back(id); - // input args type + if (!tensor_map.count(node_data->id())) { + tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_); + // record tensor. + tensor_map[node_data->id()] = tensor; + // input name. + group->input_names.push_back(node_data->id()); + // input type. args.emplace_back(tensor->buffer, ir::Argument::IO::kInput); } else { - tensor = tensor_map[id]; + tensor = tensor_map[node_data->id()]; } inputs.push_back(tensor); cinn_inputs.push_back(common::CINNValue(tensor)); @@ -1424,890 +523,68 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo } } -void OpLowerer::ElementwiseCompute(poly::StageMap& stages, - std::vector& func_args, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group) { - VLOG(3) << "ElementwiseCompute Group : " << sub_group->group_id; - auto& strategy = Operator::GetAttrs("CINNStrategy"); - for (auto& node : sub_group->nodes) { - auto node_data = GetNodeData(node); - std::vector cinn_inputs; - std::vector tensor_inputs = std::move(CollectInputTensor(func_args, tensor_map, node)); - for (auto& tensor : tensor_inputs) { - stages->InsertLazily(tensor); - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - - std::vector out_types; - std::vector> out_shapes; +// do compute +void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_map& tensor_map) { + // topological order. + auto nodes_set = group->NodeSet(); + auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_); + auto nodes_in_order = TopologicalOrder(group, v_consumers); + // find reducer. + std::unordered_set nodes_inline; + auto greducer = FindGlobalReducer(nodes_in_order); + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - - auto impl = - OpStrategy::SelectImpl(strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, this->target_)); - // do compute - common::CINNValuePack value_pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); - - if (group->master_nodes.count(node)) { - // do shedule - value_pack = impl->fschedule(value_pack); - } - - CHECK(value_pack.size() == 2); - Expr out = value_pack[0]; - poly::StageMap tmp_stages = value_pack.back(); - - tensor_map[node_data->id()] = out.as_tensor_ref(); - stages->InsertLazily(out.as_tensor_ref(), tmp_stages[out.as_tensor_ref()]); - } -} - -void OpLowerer::ElementwiseSchedule(poly::StageMap& stages, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group) { - VLOG(3) << "ElementwiseSchedule Group : " << sub_group->group_id; - auto master_node = *group->master_nodes.begin(); - auto master_node_data = GetNodeData(master_node); - auto master_stage = stages[tensor_map[master_node_data->id()]]; - auto master_shape = this->shape_dict_.at(master_node_data->id()); - for (auto& node : sub_group->nodes) { - auto node_data = GetNodeData(node); - auto node_stage = stages[tensor_map[node_data->id()]]; - auto node_shape = this->shape_dict_.at(node_data->id()); - // if group master node - if (group->master_nodes.count(node)) { - continue; - } - - if (master_shape != node_shape) { - CHECK(!group->output_nodes.count(node)) << node->id() << " is to be broadcasted, it can't be output!"; - node_stage->ComputeInline(); - continue; - } - - CHECK(master_shape == node_shape) << "node data shape must be equal to master node!"; - // if node is fringe node or internal node, fringe node is output node of sub-graph - if (group->output_nodes.count(node) || group->internal_nodes.count(node) || sub_group->internal_nodes.count(node)) { - // copy schedule from master node - node_stage->CopyTransform(master_stage); - node_stage->CopyLoopInfo(master_stage); - // internal node use buffer - if (!group->output_nodes.count(node)) { - node_stage->SetBuffer("local"); - } - // compute at master node - node_stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); - continue; - } - - // others elemenwise internal node use compute-inline - node_stage->ComputeInline(); - } -} - -void OpLowerer::ReduceCompute(poly::StageMap& stages, - std::vector& func_args, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group) { - VLOG(3) << "ReduceCompute Group : " << sub_group->group_id; - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - - Node* reducer = nullptr; - for (auto& node : sub_group->nodes) { - auto node_data = GetNodeData(node); - - std::vector cinn_inputs; - std::vector tensor_inputs = std::move(CollectInputTensor(func_args, tensor_map, node)); - VLOG(3) << "ReduceCompute tensor_inputs size is : " << tensor_inputs.size(); - for (auto& tensor : tensor_inputs) { - stages->InsertLazily(tensor); - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - - std::vector out_types; - std::vector> out_shapes; - - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - - auto impl = - OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_)); - // do compute - common::CINNValuePack value_pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); - - CHECK_GE(value_pack.size(), 2UL); - CHECK_LE(value_pack.size(), 5UL); - poly::StageMap tmp_stages = value_pack.back(); - - std::string post = ""; - for (int idx = 0; idx < value_pack.size() - 1; ++idx) { - Expr expr = value_pack[idx]; - stages->InsertLazily(expr.as_tensor_ref(), tmp_stages[expr.as_tensor_ref()]); - tensor_map[node_data->id() + post] = expr.as_tensor_ref(); - // As op may has more than 1 output tensor, using id + "_0"/"_1" as key. - post = "_" + std::to_string(idx); - } - value_pack.back() = CINNValue(stages); - - // node is kReduction - if (op_pattern_dict[node->op()] == framework::kReduction) { - reducer = node; - // do schedule - value_pack = impl->fschedule(value_pack); - } else if (group->master_nodes.count(node)) { - Expr out = value_pack[0]; - // node is master node, copy schedule from reduce node - if (reducer) { - auto reducer_data = GetNodeData(reducer); - stages[out.as_tensor_ref()]->CopyTransform(stages[tensor_map[reducer_data->id()]]); - stages[out.as_tensor_ref()]->CopyLoopInfo(stages[tensor_map[reducer_data->id()]]); - } else { - bool copied_transform = false; - for (auto rnode : group->master_nodes) { - if (op_pattern_dict[rnode->op()] == framework::kReduction) { - auto rnode_data = GetNodeData(rnode); - if (!tensor_map.count(rnode_data->id())) { - continue; - } - stages[out.as_tensor_ref()]->CopyTransform(stages[tensor_map[rnode_data->id()]]); - stages[out.as_tensor_ref()]->CopyLoopInfo(stages[tensor_map[rnode_data->id()]]); - copied_transform = true; - break; - } - } - CHECK(copied_transform) << "master node fail to copy transfrom from reduce node!"; - } - } - } -} - -void OpLowerer::ReduceSchedule(poly::StageMap& stages, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group) { - VLOG(3) << "ReduceSchedule Group : " << sub_group->group_id; - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - // assign reduce input tensor schedule, do loop transform. - auto OrderAssignReduce = [this, &stages]( - poly::Stage* stage, const std::vector& axes, const bool just_reorder = false) { - // reorder none-last reduce axis to last. - // like: shape = [16,16,16,16,16],axes = [1,3] -> new order = [0, 2, 4, 1, 3]. - std::vector order; - int n_out_dims = stage->n_out_dims(); - for (int idx = 0; idx < n_out_dims; ++idx) { - if (std::find(axes.begin(), axes.end(), idx) == axes.end()) { - order.push_back(idx); - } - } - for (auto axis : axes) { - order.push_back(axis); - } - stage->Reorder(order); - - if (just_reorder) { - return; - } - - // fuse others none-reduce axis. - int last_dimension_num = n_out_dims - axes.back() - 1; - int index = n_out_dims - last_dimension_num - axes.size(); - - // fuse last_dimension_num - 1 times - for (auto idx = index; idx < index + last_dimension_num - 1; ++idx) { - stage->Fuse(index, index + 1); - } - - if (stage->GetDimRange(index) > this->target_.max_num_threads()) { - stage->Split(index, this->target_.max_num_threads()); - } - - // fuse index - 1 times - for (int idx = 0; idx < index - 1; ++idx) { - stage->Fuse(0, 1); - } - }; - - auto WithoutLastDimInReduce = [](const std::vector& inshape, std::vector& axes) { - // if last axis is in reduce. - axes = axes.empty() ? inshape : axes; - if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || - std::find(axes.begin(), axes.end(), -1) != axes.end()) { - return false; - } - - int sum_last_axes = 1; - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - sum_last_axes *= inshape[idx]; - } - - if (sum_last_axes > 1) { - return true; - } else { - return false; - } - }; - - auto ScheduleAssignReduceWithoutLast = - [this, OrderAssignReduce](poly::Stage* stage, const std::vector& inshape, std::vector& axes) { - axes = axes.empty() ? inshape : axes; - int lane = 1; - int max_num_threads = this->target_.max_num_threads(); - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - lane *= inshape[idx]; - } - CHECK_LE(lane, max_num_threads / 2) << "Parallel threads must less equal max_num_threads/2 on gpu!"; - int pos = 0; - int index = axes.size() - 1; - for (; index >= 0; --index) { - if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { - pos = axes[index + 1]; - break; - } - - lane *= inshape[axes[index]]; - if (lane > max_num_threads / 2) { - pos = axes[index]; - break; - } - - if (index == 0) { - pos = axes[0]; - } - } - - if (lane > max_num_threads / 2) { - int prefix = inshape[axes[index]]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { - if (prefix % idx == 0) { - stage->Split(axes[index], idx); - break; - } - CHECK_GT(idx - 1, (max_num_threads / 2) / tail) << "idx should greater than (max_num_threads / 2) / tail."; - } - } - - // insert 1 - for (int idx = 0; idx < axes.size() - 1 - index; ++idx) { - stage->Split(pos, stage->GetDimRange(pos)); - } - - OrderAssignReduce(stage, axes); - }; - - auto ScheduleAssignReduceWithLast = [this, OrderAssignReduce]( - poly::Stage* stage, const std::vector& inshape, std::vector& axes) { - // find first reduce and second reduce axis. - axes = axes.empty() ? inshape : axes; - int lane = 1; - int index = static_cast(axes.size()) - 1; - auto max_num_threads = this->target_.max_num_threads(); - for (; index >= 0; --index) { - if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { - break; - } - lane *= inshape[axes[index]]; - if (index == 0 && lane <= max_num_threads) { - LOG(FATAL) << "Error! lane is less equal than max_num_threads, Please check!"; - } - if (lane >= max_num_threads / 2) { - if (lane <= max_num_threads) { - --index; - } - break; - } - } - std::vector first_axes(axes.begin(), axes.begin() + index + 1); - if (lane > max_num_threads) { - // last reduce axis size > 1024 - if (index == static_cast(axes.size()) - 1) { - int idx = max_num_threads; - do { - if (lane % idx == 0) { - stage->Split(axes[index], idx); - break; - } - --idx; - } while (idx >= max_num_threads / 2); - // if can't be divide by(1024, 512), it's shouldn't be fused. - CHECK_GE(idx, max_num_threads / 2) << "Check bounds exist, can't fuse!"; - } else { - int axis = axes[index]; - int prefix = inshape[axis]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { - if (prefix % idx == 0) { - stage->Split(axis, idx); - break; - } - CHECK_GT(idx, (max_num_threads / 2) / tail) << "Error, it's shouldn't fuse!"; - } - } - OrderAssignReduce(stage, first_axes); - } else { - int fuse_times = axes.size() - (index + 1) - 1; - for (int idx = 0; idx < fuse_times; ++idx) { - stage->Fuse(axes[index + 1], axes[index + 1] + 1); - } - OrderAssignReduce(stage, first_axes, true); - } - }; - - Node* master_node = nullptr; - for (auto node : group->master_nodes) { - if (op_pattern_dict[node->op()] != framework::kReduction) { - master_node = node; - break; - } - } - - // if not find master node, using last kReduction as master node. - if (!master_node) { - if (group->fused_sub_groups.empty()) { - master_node = group->nodes.back(); - } else { - master_node = group->fused_sub_groups.back()->nodes.back(); - } - CHECK_EQ(op_pattern_dict[master_node->op()], framework::kReduction) << "Master Node Type Must Be Reduce!"; - } - auto master_node_data = GetNodeData(master_node); - auto master_stage = stages[tensor_map[master_node_data->id()]]; - - Node* master_reducer = op_pattern_dict[master_node->op()] == framework::kReduction ? master_node : nullptr; - // find the reducer that link to master node. - if (!master_reducer) { - for (auto reducer : group->master_nodes) { - if (op_pattern_dict[reducer->op()] == framework::kReduction) { - master_reducer = reducer; - break; - } - } - } - CHECK(master_reducer) << "Can't find Master reducer!"; - auto master_reducer_data = GetNodeData(master_reducer); - auto master_reducer_stage = stages[tensor_map[master_reducer_data->id()]]; - CHECK(master_reducer->attrs.attr_store.count("dim")); - auto master_reducer_axes = absl::get>(master_reducer->attrs.attr_store.at("dim")); - CHECK(master_reducer->inlinks_in_order().size()); - CHECK(this->shape_dict_.count(master_reducer->inlinks_in_order()[0]->source()->id())); - auto master_reducer_shape = this->shape_dict_.at(master_reducer->inlinks_in_order()[0]->source()->id()); - - if (master_reducer_axes.empty()) { - for (int i = 0; i < master_reducer_shape.size(); ++i) { - master_reducer_axes.emplace_back(i); - } - } - - bool reduce_with_same_shape = true; - bool without_last_dim = WithoutLastDimInReduce(master_reducer_shape, master_reducer_axes); - if (without_last_dim) { - // check each reduce has same input shape. - for (auto reducer : group->master_nodes) { - if (op_pattern_dict[reducer->op()] != framework::kReduction) { - continue; - } - if (this->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()) != master_reducer_shape) { - reduce_with_same_shape = false; - break; - } - } - } - // update sync thread depend. - for (auto stage : stages) { - if (stage.first.find("syncthreads") != std::string::npos) { - stage.second->CtrlDepend(tensor_map[master_reducer_data->id() + "_0"]); - } - } - - VLOG(3) << "master node : " << master_node->id() << " ,reducer node : " << master_reducer->id(); - for (auto& node : sub_group->nodes) { - VLOG(3) << "Schedule node -> " << node->id(); - auto node_data = GetNodeData(node); - auto stage = stages[tensor_map[node_data->id()]]; - // if node is kReduction - if (node == master_node) { - continue; - } - // for x86 schedule. - if (this->target_ == common::DefaultHostTarget()) { - if (op_pattern_dict[node->op()] == framework::kReduction) { - if (!group->output_nodes.count(node)) { - stage->SetBuffer("local"); - } - if (node == master_reducer) { - stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); - } else { - stage->SimpleComputeAt(master_reducer_stage, master_reducer_stage->n_out_dims() - 1); - } - continue; - } - - if (group->output_nodes.count(node) || group->internal_nodes.count(node) || - sub_group->internal_nodes.count(node)) { - if (!group->output_nodes.count(node)) { - stage->SetBuffer("local"); - } - if (this->shape_dict_.at(node_data->id()) == this->shape_dict_.at(master_node_data->id())) { - stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); + // do schedule + for (auto node : nodes_in_order) { + // consumers. + auto consumers = GetConsumersInSet(node, nodes_set); + const Node* reducer = greducer ? FindNearestReducer(node, nodes_set) : greducer; + if (!reducer && greducer) { + reducer = v_consumers.count(node) ? v_consumers.find(node)->second : reducer; + if (reducer && op_pattern_dict[reducer->op()] != framework::kReduction) { + reducer = nullptr; + } + } + + // node can be inline. + if (CanbeInline(node, consumers, reducer, nodes_in_order.front(), group, nodes_set, this->shape_dict_)) { + // if exist global reduce node. + if (greducer) { + auto loops = ir_sch.GetLoops(GetNodeData(node)->id()); + if (op_pattern_dict[node->op()] == framework::kElementWise) { + ir_sch.FlattenLoops(loops, true); } else { - if (stage->n_out_dims() == master_reducer_stage->n_out_dims() - 1) { - stage->Split(0, stage->GetDimRange(0)); - } - if (stage->n_out_dims() == master_reducer_stage->n_out_dims()) { - std::vector order; - for (int idx = 0; idx < master_reducer_shape.size(); ++idx) { - if (std::find(master_reducer_axes.begin(), master_reducer_axes.end(), idx) == master_reducer_axes.end()) { - order.push_back(idx); - } - } - for (auto axis : master_reducer_axes) { - order.push_back(axis); - } - stage->Reorder(order); - stage->SimpleComputeAt(master_reducer_stage, master_reducer_stage->n_out_dims() - 1); - } else { - stage->ComputeInline(); - } + ir_sch.FlattenLoops(loops, false); } - continue; } - stage->ComputeInline(); - continue; - } - - // if node is kReduction - if (op_pattern_dict[node->op()] == framework::kReduction) { - VLOG(3) << "Reduce Schedule for Reduce Type!"; - // if node is not output node, set buffer. - if (!group->output_nodes.count(node)) { - stage->SetBuffer("local"); - } - // last dimension is not in reduce. - if (without_last_dim) { - // compute at last dimension - if (node == master_reducer) { - stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); - } else { - // if don't use block shuffle reduce. - if (!tensor_map.count(node_data->id() + "_1")) { - if (reduce_with_same_shape) { - if (master_reducer_stage->n_out_dims() > 1) { - stage->SimpleComputeAt(master_reducer_stage, master_reducer_stage->n_out_dims() - 1); - } - } else { - int num_reduce_axis = master_reducer_stage->tensor()->reduce_axis.size(); - if (master_reducer_stage->n_out_dims() > num_reduce_axis) { - stage->SimpleComputeAt(master_reducer_stage, master_reducer_stage->n_out_dims() - num_reduce_axis - 1); - } - } - } else { - auto stage_1 = stages[tensor_map[node_data->id() + "_0"]]; - auto stage_2 = stages[tensor_map[master_reducer_data->id() + "_0"]]; - // compute at master reducer - if (reduce_with_same_shape) { - stage_1->SimpleComputeAt(stage_2, stage_2->n_out_dims() - 1); - } else { - int num_reduce_axis = stage_2->tensor()->reduce_axis.size(); - stage_1->SimpleComputeAt(stage_2, stage_2->n_out_dims() - num_reduce_axis - 1); - } - // delete stage_1 compute at stage - stage_1->GetComputeAts().erase(stage->id()); - stage->CtrlDepend(tensor_map[master_reducer_data->id() + "_0"]); - // comput at master stage - stage->SimpleComputeAt(master_reducer_stage, master_reducer_stage->n_out_dims() - 1); - } - } - } else { - if (node == master_reducer) { - stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); - } else if (tensor_map.count(node_data->id() + "_1")) { - auto stage_1 = stages[tensor_map[node_data->id() + "_1"]]; - auto stage_2 = stages[tensor_map[master_reducer_data->id() + "_1"]]; - // compute at master reducer - stage_1->SimpleComputeAt(stage_2, stage_2->n_out_dims() - 1); - // delete stage_1 compute at stage_0 - auto stage_0 = stages[tensor_map[node_data->id() + "_0"]]; - stage_1->GetComputeAts().erase(stage_0->id()); - stage_0->CtrlDepend(tensor_map[master_reducer_data->id() + "_1"]); - - stage->SimpleComputeAt(master_reducer_stage, master_reducer_stage->n_out_dims() - 1); - } else if (tensor_map.count(node_data->id() + "_0")) { - stage->SimpleComputeAt(master_reducer_stage, master_reducer_stage->n_out_dims() - 1); - } else { - LOG(FATAL) << "Error! Unkown Reduce Type, Please Check!"; - } - } + auto block = ir_sch.GetBlock(GetNodeData(node)->id()); + ir_sch.ComputeInline(block); + nodes_inline.insert(node); continue; } - - // if node is internal node or output, try to copy schedule from fellow node - if (group->output_nodes.count(node) || group->internal_nodes.count(node) || sub_group->internal_nodes.count(node)) { - VLOG(3) << "Reduce Schedule for Elementwise Type"; - // if node is not output node, set buffer. - if (!group->output_nodes.count(node)) { - stage->SetBuffer("local"); - } - // node is after reduce - if (this->shape_dict_.at(node_data->id()) == this->shape_dict_.at(master_node_data->id())) { - stage->CopyTransform(master_stage); - stage->CopyLoopInfo(master_stage); - // fringe node with no consumer - stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); - continue; - } - // node is before reduce. - if (without_last_dim) { - VLOG(3) << "Reduce Schedule for WithoutLastDimInReduce"; - auto reducer_stage = master_reducer_stage; - auto reducer_shape = master_reducer_shape; - auto reducer_data = master_reducer_data; - auto node_shape = this->shape_dict_.at(node_data->id()); - - if (!reduce_with_same_shape) { - // find reducer for current node to assign - GroupPtr reducer_group = sub_group; - if (sub_group->op_pattern_kind != framework::kReduction) { - for (auto& consumer : sub_group->consumer_groups) { - if (!consumer->belong_groups.count(group)) { - continue; - } - if (consumer->op_pattern_kind == framework::kReduction) { - reducer_group = consumer; - break; - } - } - } - - if (reducer_group->op_pattern_kind == framework::kReduction) { - for (auto reducer : reducer_group->master_nodes) { - if (op_pattern_dict[reducer->op()] == framework::kReduction) { - reducer_shape = this->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - if (node_shape == reducer_shape) { - reducer_data = GetNodeData(reducer); - reducer_stage = stages[tensor_map[reducer_data->id()]]; - break; - } - } - } - } else { - for (auto reducer : group->master_nodes) { - if (op_pattern_dict[reducer->op()] == framework::kReduction) { - reducer_shape = this->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - if (node_shape == reducer_shape) { - reducer_data = GetNodeData(reducer); - reducer_stage = stages[tensor_map[reducer_data->id()]]; - break; - } - } - } - } - } - CHECK(node_shape == reducer_shape); - - // if used block shuffle reduce - if (tensor_map.count(reducer_data->id() + "_1")) { - ScheduleAssignReduceWithoutLast(stage, reducer_shape, master_reducer_axes); - auto stage_0 = stages[tensor_map[reducer_data->id() + "_0"]]; - if (stage->n_out_dims() < stage_0->n_out_dims()) { - stage->Split(0, stage->GetDimRange(0)); - } - CHECK_EQ(stage->n_out_dims(), stage_0->n_out_dims()) << "stage and stage_0's n_out_dims must be equal!"; - if (reduce_with_same_shape) { - stage->SimpleComputeAt(stage_0, stage_0->n_out_dims() - 1); - } else { - int num_reduce_axis = stage_0->tensor()->reduce_axis.size(); - stage->SimpleComputeAt(stage_0, stage_0->n_out_dims() - num_reduce_axis - 1); - } - } else { - OrderAssignReduce(stage, master_reducer_axes); - if (stage->n_out_dims() < reducer_stage->n_out_dims()) { - stage->Split(0, stage->GetDimRange(0)); - } - CHECK_EQ(stage->n_out_dims(), reducer_stage->n_out_dims()) - << "stage and master_reducer_stage's n_out_dims must be equal!"; - if (reduce_with_same_shape) { - stage->SimpleComputeAt(reducer_stage, reducer_stage->n_out_dims() - 1); - } else { - int num_reduce_axis = reducer_stage->tensor()->reduce_axis.size(); - stage->SimpleComputeAt(reducer_stage, reducer_stage->n_out_dims() - num_reduce_axis - 1); - } - } - } else { - VLOG(3) << "Reduce Schedule for WithLastDimInReduce"; - if (tensor_map.count(master_reducer_data->id() + "_1")) { - ScheduleAssignReduceWithLast(stage, master_reducer_shape, master_reducer_axes); - auto reducer_stage = stages[tensor_map[master_reducer_data->id() + "_1"]]; - if (stage->n_out_dims() < reducer_stage->n_out_dims()) { - stage->Split(0, stage->GetDimRange(0)); - } - CHECK_EQ(stage->n_out_dims(), reducer_stage->n_out_dims()) - << "stage and reducer_stage's n_out_dims must be equal!"; - stage->SimpleComputeAt(reducer_stage, reducer_stage->n_out_dims() - 1); - } else { - // compute at reduce node - auto reducer_stage = stages[tensor_map[master_reducer_data->id() + "_0"]]; - stage->CopyTransform(reducer_stage); - stage->CopyLoopInfo(reducer_stage); - stage->SimpleComputeAt(reducer_stage, reducer_stage->n_out_dims() - 1); - } - } - continue; - } - // others elemenwise internal node use compute-inline - stage->ComputeInline(); - } -} - -void OpLowerer::OutEWiseFusableCompute(poly::StageMap& stages, - std::vector& func_args, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group) { - VLOG(3) << "OutEWiseFusableCompute Group : " << sub_group->group_id; - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - - for (auto& node : sub_group->nodes) { - auto node_data = GetNodeData(node); - - std::vector cinn_inputs; - std::vector tensor_inputs = std::move(CollectInputTensor(func_args, tensor_map, node)); - for (auto& tensor : tensor_inputs) { - stages->InsertLazily(tensor); - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - - std::vector out_types; - std::vector> out_shapes; - - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - - auto impl = - OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_)); - // do compute - common::CINNValuePack value_pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); - - CHECK_GE(value_pack.size(), 2); - ir::Expr out = value_pack[0]; - poly::StageMap tmp_stages = value_pack.back(); - // node is kReduction - if (op_pattern_dict[node->op()] == framework::kOutFusible) { - // do schedule - value_pack = impl->fschedule(value_pack); - } else if (group->master_nodes.count(node)) { - // node is master node, copy schedule from OutEWiseFusable node - for (auto fnode : group->master_nodes) { - if (op_pattern_dict[fnode->op()] == framework::kOutFusible) { - auto fnode_data = GetNodeData(fnode); - tmp_stages[out.as_tensor_ref()]->CopyTransform(stages[tensor_map[fnode_data->id()]]); - tmp_stages[out.as_tensor_ref()]->CopyLoopInfo(stages[tensor_map[fnode_data->id()]]); - break; - } - } - } - - std::string postfix = ""; - for (auto idx = 0; idx < value_pack.size() - 1; ++idx) { - ir::Expr out = value_pack[idx]; - tensor_map[node_data->id() + postfix] = out.as_tensor_ref(); - stages->InsertLazily(out.as_tensor_ref(), tmp_stages[out.as_tensor_ref()]); - // update postfix - postfix = "_" + std::to_string(idx); - } - } -} - -void OpLowerer::OutEWiseFusableSchedule(poly::StageMap& stages, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group) { - VLOG(3) << "OutEWiseFusableSchedule Group : " << sub_group->group_id; - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - Node* master_node = nullptr; - for (auto node : group->master_nodes) { - if (op_pattern_dict[node->op()] != framework::kOutFusible) { - master_node = node; - break; - } - } - - // if not find master node, using last kOutFusible as master node. - if (!master_node) { - if (group->fused_sub_groups.empty()) { - master_node = group->nodes.back(); - } else { - master_node = group->fused_sub_groups.back()->nodes.back(); - } - CHECK_EQ(op_pattern_dict[master_node->op()], framework::kOutFusible) << "Master Node Type Must Be OutEWiseFusable!"; - } - - auto master_node_data = GetNodeData(master_node); - auto master_stage = stages[tensor_map[master_node_data->id()]]; - - for (auto& node : sub_group->nodes) { - auto node_data = GetNodeData(node); - auto stage = stages[tensor_map[node_data->id()]]; - // if node is master node. - if (node == master_node) { - continue; - } - - // if node is kOutFusible - if (op_pattern_dict[node->op()] == framework::kOutFusible) { - // if node is not output nodes - if (!group->output_nodes.count(node)) { - tensor_map[node_data->id()]->WithBuffer("local"); - } - // use compute at master node - stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); - continue; - } - - // if node is internal node or output, try to copy schedule from fellow node - if (group->output_nodes.count(node) || group->internal_nodes.count(node) || sub_group->internal_nodes.count(node)) { - // copy transform from master node - stage->CopyTransform(master_stage); - stage->CopyLoopInfo(master_stage); - - if (!group->output_nodes.count(node)) { - stage->SetBuffer("local"); - } - // fringe node with no consumer - stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); - continue; - } - // others elemenwise internal node use compute-inline - stage->ComputeInline(); - } -} - -std::vector OpLowerer::LowerNonFusibleOp(GroupPtr& group) { - VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id; - // get input tensor and output tensor - CHECK(group->nodes.size() || group->fused_sub_groups.size()); - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - // node - auto node = group->fused_sub_groups.size() ? group->fused_sub_groups[0]->nodes.front() : group->nodes.front(); - // collect input - std::vector func_args; - std::vector tensor_inputs; - std::vector cinn_inputs; - std::unordered_map tensor_map; - for (auto& link : node->inlinks_in_order(true)) { - auto source = link->source(); - CHECK(source); - auto source_data = source->safe_as(); - CHECK(source_data); - - auto id = source_data->id(); - auto shape = this->shape_dict_.at(id); - auto dtype = this->type_dict_.at(id); - - ir::Tensor tensor; - if (!tensor_map.count(id)) { - if (dtype.is_float(32)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_float(64)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_float(16)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_bool()) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(8)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(16)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(32)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_int(64)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(8)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(16)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(32)) { - tensor = lang::Placeholder(id, shape); - } else if (dtype.is_uint(64)) { - tensor = lang::Placeholder(id, shape); + // find master to computeat. + auto master = GetMasterToComputeAt(node, nodes_in_order, nodes_inline, nodes_set, v_consumers, this->shape_dict_); + // assign to reducer/master loop. + if (reducer) { + // if node is vertical with reduce, loop assign reducer. + LoopAssignReduce(ir_sch, node, reducer, this->target_, tensor_map, this->shape_dict_); + } else if (greducer) { + // if node is horizontal with reduce or node is reduce, loop assign master. + auto loops = ir_sch.GetLoops(GetNodeData(node)->id()); + if (op_pattern_dict[node->op()] == framework::kElementWise) { + ir_sch.FlattenLoops(loops, true); + } else if (op_pattern_dict[node->op()] != framework::kReduction) { + ir_sch.FlattenLoops(loops, false); } - tensor_map[id] = tensor; - // recored func input args - func_args.push_back(tensor); - // collect input node data name. - group->input_names.push_back(tensor->name); - } else { - tensor = tensor_map[id]; - } - - tensor_inputs.push_back(tensor); - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - - std::vector out_types; - std::vector> out_shapes; - - auto node_datas = GetAllNodeData(node); - for (auto node_data : node_datas) { - // collect output node data name. - group->output_names.push_back(node_data->id()); - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - } - - auto impl = - OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_)); - // if node op is custom_call, apply custom_call compute. - if (node->op()->name == "custom_call") { - std::string external_api; - if (node->attrs.attr_store.count("custom_call")) { - external_api = absl::get(node->attrs.attr_store.at("custom_call")); - } else { - external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_); - } - std::vector compute_args = {common::CINNValue(group->GetFuncName()), - common::CINNValue(external_api)}; - - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{compute_args}); - CHECK_EQ(pack.size(), 1UL); - // reset input names as extern api input args can't be remove duplicate. - group->input_names.clear(); - for (auto& inode : node->inlinks_in_order(true)) { - group->input_names.push_back(inode->source()->as()->id()); - } - return {pack[0].operator ir::Expr().as_lowered_func_ref()}; - } - - // do compute - common::CINNValuePack value_pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); - // do schedule - value_pack = impl->fschedule(value_pack); - - CHECK(value_pack.size() >= 2); - poly::StageMap stages = value_pack.back(); - // lazily insert input tensor. - for (auto tensor_input : tensor_inputs) { - stages->InsertLazily(tensor_input); - } - - for (int idx = 0; idx < value_pack.size() - 1; ++idx) { - Expr out = value_pack[idx]; - auto tensor = out.as_tensor_ref(); - // collect output tensor - if (!tensor->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) { - func_args.push_back(tensor); } + // do loop fuse. + LoopComputeAt(ir_sch, node, master ? master : nodes_in_order.front(), group, this->shape_dict_, tensor_map); } - return lang::LowerVec(group->GetFuncName(), stages, func_args, {}, {}, nullptr, this->target_); + SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map); } } // namespace framework diff --git a/cinn/hlir/framework/op_lowering.h b/cinn/hlir/framework/op_lowering.h old mode 100755 new mode 100644 index 36a22019c0..090fd39932 --- a/cinn/hlir/framework/op_lowering.h +++ b/cinn/hlir/framework/op_lowering.h @@ -72,9 +72,7 @@ class OpLowerer { std::vector LowerWithoutSchedule(GroupPtr& group); private: - std::vector LowerOp(ComputeFunction, ScheduleFunction, GroupPtr&); - std::vector LowerNonFusibleOp(GroupPtr&); - std::vector IRLowerOp(IRComputeFunction, IRScheduleFunction, GroupPtr&); + std::vector IRLowerOp(IRComputeFunction, GroupPtr&); std::vector IRLowerNonFusibleOp(GroupPtr&, bool); std::vector IRLowerOpWithoutSchedule(IRComputeFunction, GroupPtr&); #define DEFINE_IR_COMPUTE_SCHDULE(type) \ @@ -91,29 +89,14 @@ class OpLowerer { Node*& first, \ Node*& second); -#define DEFINE_COMPUTE_SCHDULE(type) \ - void type##Compute(poly::StageMap& stages, \ - std::vector& func_args, \ - std::unordered_map& tensor_map, \ - const GroupPtr& group, \ - const GroupPtr& sub_group); \ - void type##Schedule(poly::StageMap& stages, \ - std::unordered_map& tensor_map, \ - const GroupPtr& group, \ - const GroupPtr& sub_group); - // compute and schedule DEFINE_IR_COMPUTE_SCHDULE(Elementwise); DEFINE_IR_COMPUTE_SCHDULE(Reduce); DEFINE_IR_COMPUTE_SCHDULE(OutEWiseFusable); - DEFINE_COMPUTE_SCHDULE(Elementwise); - DEFINE_COMPUTE_SCHDULE(Reduce); - DEFINE_COMPUTE_SCHDULE(OutEWiseFusable); - - std::vector CollectInputTensor(std::vector& func_args, - std::unordered_map& tensor_map, - const Node* node); + void IRSchedule(ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_map& tensor_map); Target target_; const absl::flat_hash_map& type_dict_; diff --git a/cinn/hlir/framework/op_lowering_test.cc b/cinn/hlir/framework/op_lowering_test.cc index 2f736d6eca..38c6cfdb4d 100644 --- a/cinn/hlir/framework/op_lowering_test.cc +++ b/cinn/hlir/framework/op_lowering_test.cc @@ -54,13 +54,7 @@ void CodeGen(ir::LoweredFunc& func) { #endif } -TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_0) { - NetBuilder net_builder("Reduce_Keep_Dim_Fuse_Elementwise_0"); - { - auto A = net_builder.CreateInput(Float(32), {16, 64, 1024}, "A"); - auto B = net_builder.ReduceSum(A, {2}, true); - } - +void Compile(NetBuilder& net_builder) { auto program = net_builder.Build(); auto target = common::DefaultTarget(); RunDecomposer(&program, target); @@ -80,6 +74,283 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_0) { } } +TEST(OpFusionPass, Reduce_With_Last_Axis_1) { + NetBuilder net_builder("Reduce_With_Last_Axis_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {10, 100, 1}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2}); + } + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_With_Output) { + NetBuilder net_builder("Reduce_Fuse_Broadcast_With_Output"); + auto layer_norm_51__tmp_1 = net_builder.CreateInput(Float(32), {256}, "layer_norm_51__tmp_1"); + auto var_3216 = net_builder.CreateInput(Float(32), {256, 60}, "var_3216"); + auto var_3202 = net_builder.CreateInput(Float(32), {1, 60}, "var_3202"); + auto var_3212 = net_builder.CreateInput(Float(32), {256, 60}, "var_3212"); + + auto var_3206 = net_builder.Reshape(layer_norm_51__tmp_1, {256, 1}); + auto composite_tmp_8 = net_builder.FillConstant({256, 1}, 1e-5, "composite_tmp_8"); + auto var_3214 = net_builder.Add(var_3206, composite_tmp_8); + auto composite_tmp_10 = net_builder.FillConstant({256, 1}, 1.0, "composite_tmp_10"); + auto var_3220 = net_builder.Divide(composite_tmp_10, var_3214); + auto var_3226 = net_builder.Sqrt(var_3220); + auto var_3224 = net_builder.Scale(var_3220, -1.0, 0.0, true); + auto var_3366 = net_builder.BroadcastTo(var_3224, {256, 60}); + auto var_3228 = net_builder.Multiply(var_3366, var_3216); + auto var_3368 = net_builder.BroadcastTo(var_3202, {256, 60}); + auto var_3236 = net_builder.Multiply(var_3228, var_3212); + auto var_3244 = net_builder.Multiply(var_3236, var_3368); + auto var_3252 = net_builder.ReduceSum(var_3244, {1}, true); + auto var_3232 = net_builder.Scale(var_3226, 0.0166667, 0.0, true); + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_Layernorm) { + int h = 32, w = 1024; + NetBuilder net_builder("Reduce_Fuse_Broadcast_Layernorm"); + // create model + { + // x + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + // x * x + auto B = net_builder.Multiply(A, A); + // sum x + auto C = net_builder.ReduceSum(A, {1}); + // sum x*x + auto D = net_builder.ReduceSum(B, {1}); + // constant w + auto E = net_builder.FillConstant({h}, 1024.0f, "E"); + // mean + auto F = net_builder.Divide(C, E); + auto FF = net_builder.BroadcastTo(F, {h, w}, {0}); + // mean x*x + auto G = net_builder.Divide(D, E); + // mean * mean + auto H = net_builder.Multiply(F, F); + // var^2 + auto I = net_builder.Subtract(G, H); + // eps + auto J = net_builder.FillConstant({h}, 1e-10f, "J"); + // eps + delta + auto K = net_builder.Add(I, J); + // var + auto L = net_builder.Sqrt(K); + auto LL = net_builder.BroadcastTo(L, {h, w}, {0}); + // x - mean + auto M = net_builder.Subtract(A, FF); + // /var + auto N = net_builder.Divide(M, LL); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_Softmax) { + int h = 32, w = 1024; + NetBuilder net_builder("Reduce_Fuse_Broadcast_Softmax"); + // create model + { + // softmax + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + // redece max + auto B = net_builder.ReduceMax(A, {1}); + // broadcast + auto C = net_builder.BroadcastTo(B, {h, w}, {0}); + // x - max(x) + auto D = net_builder.Subtract(A, C); + // exp(x) + auto E = net_builder.Exp(D); + // reduce sum + auto F = net_builder.ReduceSum(E, {1}); + // broadcast + auto G = net_builder.BroadcastTo(F, {h, w}, {0}); + // exp(x)/sum(exp(x)) + auto H = net_builder.Divide(E, G); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_1) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Fuse_Broadcast_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h * w}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + auto C = net_builder.BroadcastTo(B, {h * w}, {0}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_2) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Fuse_Broadcast_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {0, 1}); + auto C = net_builder.BroadcastTo(B, {h, w}, {1}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_3) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Fuse_Broadcast_3"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_4) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Fuse_Broadcast_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {1}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_5) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Fuse_Broadcast_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + auto D = net_builder.ReduceSum(C, {1, 2}); + auto E = net_builder.BroadcastTo(D, {h, h, w}, {0}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Reduce_Fuse_Broadcast_6) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Fuse_Broadcast_6"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + auto D = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto E = net_builder.BroadcastTo(D, {h, h, w}, {1, 2}); + auto F = net_builder.Add(C, E); + } + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_0) { + NetBuilder net_builder("Reduce_Dim_Equal_One_0"); + { + auto A = net_builder.CreateInput(Float(32), {1, 1000}, "A"); + auto B = net_builder.CreateInput(Float(32), {1, 1000}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {1}, false); + auto E = net_builder.ReduceSum(C, {1}, false); + auto F = net_builder.Add(D, E); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_1) { + NetBuilder net_builder("Reduce_Dim_Equal_One_1"); + { + auto A = net_builder.CreateInput(Float(32), {32, 32}, "A"); + auto B = net_builder.ReduceSum(A, {0, 1}, false); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_2) { + NetBuilder net_builder("Reduce_Dim_Equal_One_2"); + { + auto A = net_builder.CreateInput(Float(32), {32, 1024}, "A"); + auto B = net_builder.ReduceSum(A, {1}, false); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_3) { + NetBuilder net_builder("Reduce_Dim_Equal_One_3"); + { + auto A = net_builder.CreateInput(Float(32), {32, 1024}, "A"); + auto B = net_builder.ReduceSum(A, {0, 1}, false); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_4) { + NetBuilder net_builder("Reduce_Dim_Equal_One_4"); + { + auto A = net_builder.CreateInput(Float(32), {32, 32, 1024}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2}, false); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_5) { + NetBuilder net_builder("Reduce_Dim_Equal_One_5"); + { + auto A = net_builder.CreateInput(Float(32), {32, 32, 32, 256}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2, 3}, false); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_6) { + NetBuilder net_builder("Reduce_Dim_Equal_One_6"); + { + auto A = net_builder.CreateInput(Float(32), {32, 32, 256}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Dim_Equal_One_7) { + NetBuilder net_builder("Reduce_Dim_Equal_One_7"); + { + auto A = net_builder.CreateInput(Float(32), {1, 1, 1024}, "A"); + auto B = net_builder.ReduceSum(A, {2}, false); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_0) { + NetBuilder net_builder("Reduce_Keep_Dim_Fuse_Elementwise_0"); + { + auto A = net_builder.CreateInput(Float(32), {16, 64, 1024}, "A"); + auto B = net_builder.ReduceSum(A, {2}, true); + } + + Compile(net_builder); +} + TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_1) { NetBuilder net_builder("Reduce_Keep_Dim_Fuse_Elementwise_1"); { @@ -89,23 +360,7 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_1) { auto D = net_builder.Add(B, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_2) { @@ -117,23 +372,7 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_2) { auto D = net_builder.Add(B, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_3) { @@ -143,23 +382,7 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_3) { auto B = net_builder.ReduceSum(A, {2, 3}, true); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_4) { @@ -169,23 +392,7 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_4) { auto B = net_builder.ReduceSum(A, {2}, true); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_5) { @@ -195,23 +402,7 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_5) { auto B = net_builder.ReduceSum(A, {2}, true); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_6) { @@ -221,23 +412,7 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_6) { auto B = net_builder.ReduceSum(A, {2}, true); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_7) { @@ -247,23 +422,7 @@ TEST(OP_LOWERING, Reduce_Keep_Dim_Fuse_Elementwise_7) { auto B = net_builder.ReduceSum(A, {1, 3}, true); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_Test_Concat_Before_Reduce) { @@ -276,23 +435,7 @@ TEST(OP_LOWERING, Elementwise_Test_Concat_Before_Reduce) { auto E = net_builder.ReduceSum(D, {2}, false); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_Test_Reshape_Before_Reduce) { @@ -307,23 +450,7 @@ TEST(OP_LOWERING, Elementwise_Test_Reshape_Before_Reduce) { auto G = net_builder.ReduceSum(F, {0, 1}, false); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_Test_Reshape_After_Reduce) { @@ -339,23 +466,7 @@ TEST(OP_LOWERING, Elementwise_Test_Reshape_After_Reduce) { auto H = net_builder.Add(B, G); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_Test_Reshape_Fuse_Concat) { @@ -373,23 +484,7 @@ TEST(OP_LOWERING, Elementwise_Test_Reshape_Fuse_Concat) { auto I = net_builder.Concat({G, H}, 2); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_TEST_Split_0) { @@ -399,23 +494,7 @@ TEST(OP_LOWERING, Elementwise_TEST_Split_0) { auto B = net_builder.Split(A, {3, 5, 16, 2, 6}, 0); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_TEST_Split_1) { @@ -425,23 +504,7 @@ TEST(OP_LOWERING, Elementwise_TEST_Split_1) { auto B = net_builder.Split(A, {32, 32, 32, 32}, 1); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_TEST_Split_2) { @@ -451,23 +514,7 @@ TEST(OP_LOWERING, Elementwise_TEST_Split_2) { auto B = net_builder.Split(A, {64, 32, 32}, 1); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_TEST_0) { @@ -478,23 +525,7 @@ TEST(OP_LOWERING, Elementwise_TEST_0) { auto o2 = net_builder.Scale(x, -1.0, 0.0); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, NonFusibleOp_TEST_0) { @@ -504,22 +535,7 @@ TEST(OP_LOWERING, NonFusibleOp_TEST_0) { auto B = net_builder.Reshape(A, {9801, 2}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, NonFusibleOp_TEST_1) { @@ -530,22 +546,7 @@ TEST(OP_LOWERING, NonFusibleOp_TEST_1) { auto C = net_builder.Matmul(A, B); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, NonFusibleOp_TEST_2) { @@ -555,22 +556,7 @@ TEST(OP_LOWERING, NonFusibleOp_TEST_2) { auto B = net_builder.Matmul(A, A); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, NonFusibleOp_TEST_3) { @@ -580,52 +566,21 @@ TEST(OP_LOWERING, NonFusibleOp_TEST_3) { auto C = net_builder.Split(A, {4}, 1); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } -} - -#ifdef CINN_WITH_CUDA -TEST(OP_LOWERING, NonFusibleOp_TEST_4) { - NetBuilder net_builder("NonFusibleOp_TEST_4"); - { - auto A = net_builder.CreateInput(Float(32), {128, 128}, "A"); - auto B = net_builder.CreateInput(Float(32), {128, 128}, "B"); - auto C = net_builder.CreateInput(Float(32), {128, 128}, "C"); - auto D = net_builder.Matmul(A, B); - auto E = net_builder.Add(C, D); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "TransToCustomCallPass"); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); + Compile(net_builder); +} + +#ifdef CINN_WITH_CUDA +TEST(OP_LOWERING, NonFusibleOp_TEST_4) { + NetBuilder net_builder("NonFusibleOp_TEST_4"); + { + auto A = net_builder.CreateInput(Float(32), {128, 128}, "A"); + auto B = net_builder.CreateInput(Float(32), {128, 128}, "B"); + auto C = net_builder.CreateInput(Float(32), {128, 128}, "C"); + auto D = net_builder.Matmul(A, B); + auto E = net_builder.Add(C, D); } + + Compile(net_builder); } #endif @@ -638,22 +593,7 @@ TEST(OP_LOWERING, Transform_TEST_0) { auto D = net_builder.Concat({A, B, C}, 1); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_Test_0) { @@ -670,55 +610,7 @@ TEST(OP_LOWERING, Elementwise_Test_0) { auto G = net_builder.Add(E, F); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } -} - -TEST(OP_LOWERING, Elementwise_Test_1) { - int h = 32, w = 32; - NetBuilder net_builder("Elementwise_Test_1"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.Add(E, C); - auto G = net_builder.Add(E, D); - auto H = net_builder.Add(F, G); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Elementwise_Test_2) { @@ -736,22 +628,7 @@ TEST(OP_LOWERING, Elementwise_Test_2) { auto H = net_builder.Add(F, G); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_0) { @@ -763,22 +640,7 @@ TEST(OP_LOWERING, Reduce_Test_0) { auto B = net_builder.ReduceSum(A, {0, 1}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_1) { @@ -790,22 +652,7 @@ TEST(OP_LOWERING, Reduce_Test_1) { auto B = net_builder.ReduceSum(A, {0, 1, 2}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_2) { @@ -817,22 +664,7 @@ TEST(OP_LOWERING, Reduce_Test_2) { auto B = net_builder.ReduceSum(A, {0, 1}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_3) { @@ -844,22 +676,7 @@ TEST(OP_LOWERING, Reduce_Test_3) { auto B = net_builder.ReduceSum(A, {0, 1, 2}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_4) { @@ -871,22 +688,7 @@ TEST(OP_LOWERING, Reduce_Test_4) { auto B = net_builder.ReduceSum(A, {0}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_5) { @@ -898,22 +700,7 @@ TEST(OP_LOWERING, Reduce_Test_5) { auto B = net_builder.ReduceSum(A, {0}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_6) { @@ -925,22 +712,7 @@ TEST(OP_LOWERING, Reduce_Test_6) { auto B = net_builder.ReduceSum(A, {0}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_7) { @@ -952,22 +724,7 @@ TEST(OP_LOWERING, Reduce_Test_7) { auto B = net_builder.ReduceSum(A, {1}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_8) { @@ -979,49 +736,19 @@ TEST(OP_LOWERING, Reduce_Test_8) { auto B = net_builder.ReduceSum(A, {1, 2}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_9) { int n = 16, c = 128, h = 56, w = 56; NetBuilder net_builder("Reduce_Test_9"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {n, c, h, w}, "A"); - auto B = net_builder.ReduceSum(A, {0, 2, 3}); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); + // create model + { + auto A = net_builder.CreateInput(Float(32), {n, c, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2, 3}); } + + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Test_10) { @@ -1033,22 +760,7 @@ TEST(OP_LOWERING, Reduce_Test_10) { auto B = net_builder.ReduceSum(A, {1}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_0) { @@ -1063,23 +775,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_0) { auto D = net_builder.Add(B, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_1) { @@ -1090,26 +786,10 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_1) { auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); auto D = net_builder.Add(A, B); - auto E = net_builder.ReduceSum(D, {1}); + auto E = net_builder.ReduceSum(D, {0}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_2) { @@ -1128,23 +808,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_2) { auto H = net_builder.Add(E, G); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_3) { @@ -1158,23 +822,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_3) { auto D = net_builder.Add(B, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_4) { @@ -1190,26 +838,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_4) { auto E = net_builder.ReduceSum(C, {0}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_5) { @@ -1225,26 +854,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_5) { auto E = net_builder.ReduceSum(C, {1}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_6) { @@ -1262,26 +872,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_6) { auto I = net_builder.Add(F, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_7) { @@ -1299,26 +890,160 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_7) { auto I = net_builder.Add(F, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); + Compile(net_builder); +} - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 5); +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_0) { + int h = 128, w = 128; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_0"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0}); + auto F = net_builder.ReduceSum(D, {0}); + } - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); + Compile(net_builder); +} - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_1) { + int h = 128, w = 1024; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0}); + auto F = net_builder.ReduceSum(D, {0}); + } - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_2) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0}); + auto F = net_builder.ReduceSum(D, {0}); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_3) { + int h = 32, w = 1024; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_3"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {1}); + auto F = net_builder.ReduceSum(D, {1}); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_4) { + int h = 32, w = 1024; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0, 1}); + auto F = net_builder.ReduceSum(D, {0, 1}); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_5) { + int h = 32, w = 128; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0, 1}); + auto F = net_builder.ReduceSum(D, {0, 1}); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_6) { + int h = 32, w = 128; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_6"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, h, w}, "B"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0, 2}); + auto F = net_builder.ReduceSum(D, {0, 2}); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_7) { + int h = 32, w = 128; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_7"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0, 1}); + auto F = net_builder.ReduceSum(D, {0, 1}); + auto G = net_builder.Add(E, C); + auto H = net_builder.Add(F, C); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_8) { + int h = 32, w = 128; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_8"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h + h, w}, "B"); + auto E = net_builder.ReduceSum(A, {0}); + auto F = net_builder.ReduceSum(B, {0}); + auto G = net_builder.Add(E, F); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fuse_Reduce_Test_9) { + int h = 32, w = 1024; + NetBuilder net_builder("Reduce_Fuse_Reduce_Test_9"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h * 2, w}, "B"); + auto E = net_builder.ReduceSum(A, {0}); + auto F = net_builder.ReduceSum(B, {0}); + auto G = net_builder.Add(E, F); } + + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_8) { @@ -1336,26 +1061,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_8) { auto I = net_builder.Add(F, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 5); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_9) { @@ -1373,26 +1079,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_9) { auto I = net_builder.Add(F, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 5); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_10) { @@ -1406,26 +1093,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_10) { auto D = net_builder.Add(B, C); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_11) { @@ -1440,26 +1108,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_11) { auto F = net_builder.ReduceSum(D, {0, 2, 3}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_12) { @@ -1474,29 +1123,9 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_12) { auto F = net_builder.ReduceSum(D, {0, 2, 3}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - CodeGen(lowered_func[0]); - } + Compile(net_builder); } -/* -TODO:exist coredump. + TEST(OP_LOWERING, Reduce_Fusion_Test_13) { int n = 8, c = 8, h = 8, w = 8; NetBuilder net_builder("Reduce_Fusion_Test_13"); @@ -1509,29 +1138,8 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_13) { auto F = net_builder.ReduceSum(D, {0, 1, 2}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } -*/ TEST(OP_LOWERING, Reduce_Fusion_Test_14) { int n = 8, c = 8, h = 8, w = 8; @@ -1545,27 +1153,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_14) { auto F = net_builder.ReduceSum(D, {0, 3, 4}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_15) { @@ -1580,28 +1168,9 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_15) { auto F = net_builder.ReduceSum(D, {0}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } + TEST(OP_LOWERING, Reduce_Fusion_Test_16) { int n = 128, c = 128, h = 28, w = 28; NetBuilder net_builder("Reduce_Fusion_Test_16"); @@ -1614,27 +1183,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_16) { auto F = net_builder.ReduceSum(D, {0, 2, 3}); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 3); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_17) { @@ -1649,24 +1198,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_17) { auto G = net_builder.Add(E, F); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_18) { @@ -1681,27 +1213,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_18) { auto G = net_builder.Add(E, F); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - // hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - // CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_19) { @@ -1716,27 +1228,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_19) { auto G = net_builder.Add(E, F); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - // hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - // CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_20) { @@ -1757,33 +1249,42 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_20) { auto K = net_builder.Add(H, J); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - // hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - // CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } TEST(OP_LOWERING, Reduce_Fusion_Test_21) { int h = 128, w = 4; NetBuilder net_builder("Reduce_Fusion_Test_21"); // create model + { + auto A0 = net_builder.CreateInput(Float(32), {256, w}, "A0"); + auto B0 = net_builder.CreateInput(Float(32), {256, w}, "B0"); + auto C0 = net_builder.CreateInput(Float(32), {55200, w}, "C0"); + auto D0 = net_builder.CreateInput(Float(32), {2750, w}, "D0"); + auto A1 = net_builder.CreateInput(Float(32), {256, w}, "A1"); + auto B1 = net_builder.CreateInput(Float(32), {256, w}, "B1"); + auto C1 = net_builder.CreateInput(Float(32), {55200, w}, "C1"); + auto D1 = net_builder.CreateInput(Float(32), {2750, w}, "D1"); + auto AA = net_builder.Add(A0, A1); + auto BB = net_builder.Add(B0, B1); + auto CC = net_builder.Add(C0, C1); + auto DD = net_builder.Add(D0, D1); + auto E = net_builder.ReduceSum(AA, {0}); + auto F = net_builder.ReduceSum(BB, {0}); + auto G = net_builder.ReduceSum(CC, {0}); + auto H = net_builder.ReduceSum(DD, {0}); + auto I = net_builder.Add(E, F); + auto J = net_builder.Add(G, I); + auto K = net_builder.Add(H, J); + } + + Compile(net_builder); +} + +TEST(OP_LOWERING, Reduce_Fusion_Test_22) { + int h = 128, w = 4; + NetBuilder net_builder("Reduce_Fusion_Test_22"); + // create model { auto A0 = net_builder.CreateInput(Float(32), {256, w}, "A0"); auto B0 = net_builder.CreateInput(Float(32), {256, w}, "B0"); @@ -1810,27 +1311,7 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_21) { auto DDD = net_builder.Add(DD, D1); } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - RunDecomposer(&program, target); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 9); - - hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - CHECK_EQ(graph->fusion_groups.size(), 1); - - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, target); - for (auto& fusion_op : graph->fusion_groups) { - auto lowered_func = op_lowerer.Lower(fusion_op); - CHECK_EQ(lowered_func.size(), 1); - LOG(INFO) << lowered_func[0]; - CodeGen(lowered_func[0]); - } + Compile(net_builder); } } // namespace framework diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc new file mode 100644 index 0000000000..06079ef3d3 --- /dev/null +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -0,0 +1,1377 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/framework/op_lowering_util.h" +#ifdef CINN_WITH_CUDA +#include "cinn/runtime/cuda/float16.h" +#endif +#include + +namespace cinn { +namespace hlir { +namespace framework { + +std::vector GetInputNodeData(const Node* node) { + std::vector producers; + for (auto& link : node->inlinks_in_order(true)) { + auto node_data = link->source()->safe_as(); + producers.push_back(node_data); + } + return producers; +} + +ir::Tensor GetTensor(const NodeData* node_data, + const absl::flat_hash_map& type_dict, + const absl::flat_hash_map& shape_dict) { + auto dtype = type_dict.at(node_data->id()); + if (dtype.is_float(32)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_float(64)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_float(16)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_bool()) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_int(8)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_int(16)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_int(32)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_int(64)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_uint(8)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_uint(16)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_uint(32)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else if (dtype.is_uint(64)) { + return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); + } else { + LOG(FATAL) << "Unsupport dtype: " << dtype; + } +} + +std::vector CollectInputTensor(const Node* node, + std::vector& func_args, + std::unordered_map& tensor_map, + const absl::flat_hash_map& type_dict, + const absl::flat_hash_map& shape_dict) { + std::vector tensors; + // get all input nodes + for (auto& node_data : GetInputNodeData(node)) { + CHECK(node_data); + auto tensor = GetTensor(node_data, type_dict, shape_dict); + if (!tensor_map.count(node_data->id())) { + tensor_map[node_data->id()] = tensor; + // record func input args + func_args.push_back(tensor); + } + tensors.push_back(tensor); + } + return tensors; +} + +NodeData* GetNodeData(const Node* node) { + auto node_data = (*node->outlinks().begin())->sink()->safe_as(); + CHECK(node_data); + return node_data; +} + +std::vector GetAllNodeData(const Node* node) { + std::vector node_datas; + for (auto& link : node->outlinks_in_order(true)) { + auto node_data = link->sink()->safe_as(); + CHECK(node_data); + node_datas.push_back(node_data); + } + + return node_datas; +} + +std::vector GetConsumers(const Node* node) { + std::vector consumers; + auto node_data = GetNodeData(node); + for (auto& link : node_data->outlinks()) { + auto consumer = link->sink()->safe_as(); + CHECK(consumer); + consumers.push_back(consumer); + } + return consumers; +} + +std::vector GetConsumersInSet(const Node* node, const std::unordered_set& node_set) { + std::vector consumers; + auto node_data = GetNodeData(node); + for (auto& link : node_data->outlinks()) { + auto consumer = link->sink()->safe_as(); + CHECK(consumer); + if (node_set.count(consumer)) { + consumers.push_back(consumer); + } + } + return consumers; +} + +std::vector GetProducers(const Node* node) { + std::vector producers; + for (auto& link : node->inlinks_in_order(true)) { + auto data = link->source()->safe_as(); + CHECK(data); + if (data->source_node.get()) { + producers.push_back(data->source_node.get()); + } + } + return producers; +} + +std::vector GetProducersInSet(const Node* node, const std::unordered_set& node_set) { + std::vector producers; + for (auto& link : node->inlinks_in_order(true)) { + auto data = link->source()->safe_as(); + CHECK(data); + if (data->source_node.get() && node_set.count(data->source_node.get())) { + producers.push_back(data->source_node.get()); + } + } + return producers; +} + +bool IsConstOp(const framework::Node* node) { + static std::unordered_set const_op_type = {"const_scalar", "fill_constant", "arange"}; + if (const_op_type.count(node->op()->name)) { + return true; + } else { + return false; + } +} + +std::vector GetInputShape(const Node* node, const absl::flat_hash_map& shape_dict) { + auto producers = GetProducers(node); + CHECK(producers.size()); + + auto producer_data = GetNodeData(producers.front()); + return shape_dict.at(producer_data->id()); +} + +std::vector GetOutputShape(const Node* node, const absl::flat_hash_map& shape_dict) { + auto node_data = GetNodeData(node); + return shape_dict.at(node_data->id()); +} + +Node* FindGlobalReducer(const std::vector& nodes_in_order) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + for (auto iter = nodes_in_order.rbegin(); iter != nodes_in_order.rend(); ++iter) { + if (op_pattern_dict[(*iter)->op()] == framework::kReduction) { + return *iter; + } + } + + return nullptr; +} + +using Visitor = std::function(const Node*, const std::unordered_set&)>; +Node* FindReducerInRoute(const Node* node, const std::unordered_set& nodes_set, Visitor visitor) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + std::queue candidates; + candidates.push(node); + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto consumer : visitor(candidate, nodes_set)) { + if (op_pattern_dict[consumer->op()] == framework::kReduction) { + return consumer; + } + candidates.push(consumer); + } + } + + return nullptr; +} + +Node* FindNearestReducer(const Node* node, const std::unordered_set& nodes_set) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + // from consumers find reducer. + auto reducer = FindReducerInRoute(node, nodes_set, GetConsumersInSet); + if (reducer) + return reducer; + else + return FindReducerInRoute(node, nodes_set, GetProducersInSet); +} + +std::unordered_map BuildVirtualConsumer(const GroupPtr& group, + const absl::flat_hash_map& shape_dict) { + std::unordered_map virtual_consumers; + std::unordered_set nodes_set = group->NodeSet(); + if (group->op_pattern_kind != framework::kReduction) { + return virtual_consumers; + } + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + Node* g_node = nullptr; + for (auto t_node : group->output_nodes) { + if (op_pattern_dict[t_node->op()] == framework::kReduction) { + continue; + } + // producer exits reduce-sum and not consumers. + if (FindReducerInRoute(t_node, nodes_set, GetProducersInSet) && GetConsumersInSet(t_node, nodes_set).size() == 0) { + g_node = t_node; + break; + } + } + + // try to find reducer with different shape. + for (auto t_node : group->output_nodes) { + if (op_pattern_dict[t_node->op()] == framework::kReduction) { + if (g_node) { + virtual_consumers[t_node] = g_node; + } + continue; + } + if (FindNearestReducer(t_node, nodes_set)) { + continue; + } + + std::unordered_set visited; + std::queue candidates; + + candidates.push(t_node); + visited.insert(t_node); + // from producers find reducer consumer. + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto producer : GetProducersInSet(candidate, nodes_set)) { + if (visited.count(producer)) { + continue; + } + + auto reducer = FindReducerInRoute(producer, nodes_set, GetConsumersInSet); + if (reducer) { + virtual_consumers[t_node] = reducer; + break; + } + visited.insert(producer); + } + // if find horizontal reducer. + if (virtual_consumers.count(t_node)) { + break; + } + } + + if (virtual_consumers.count(t_node)) { + continue; + } + + if (t_node != g_node && g_node) { + virtual_consumers[t_node] = g_node; + } + } + return virtual_consumers; +} + +std::vector FindConsumers(Node* node, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers) { + auto consumers = GetConsumersInSet(node, nodes_set); + if (virtual_consumers.count(node)) { + consumers.push_back(virtual_consumers.find(node)->second); + } + return consumers; +} + +std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_map& virtual_consumers) { + std::vector nodes_in_order; + std::unordered_set nodes_set = group->NodeSet(); + + while (!nodes_set.empty()) { + auto tmp_node_set = nodes_set; + for (auto node : tmp_node_set) { + auto consumers = FindConsumers(node, nodes_set, virtual_consumers); + bool cant_be_erase = false; + for (auto consumer : consumers) { + if (nodes_set.count(consumer)) { + cant_be_erase = true; + break; + } + } + + if (cant_be_erase) continue; + nodes_in_order.push_back(node); + nodes_set.erase(node); + } + } + + return nodes_in_order; +} + +bool WithoutLastDimInReduce(const std::vector& shape, const std::vector& axes) { + if (axes.empty()) { + return false; + } + // if last axis is in reduce. + if (std::find(axes.begin(), axes.end(), shape.size() - 1) != axes.end() || + std::find(axes.begin(), axes.end(), -1) != axes.end()) { + return false; + } + + int sum_last_axes = 1; + for (int idx = axes.back() + 1; idx < shape.size(); ++idx) { + sum_last_axes *= shape[idx]; + } + + if (sum_last_axes > 1) { + return true; + } else { + return false; + } +} + +void LoopOrderAssignReduce(ir::IRSchedule& ir_sch, + const std::string& block_name, + const std::vector& axes, + const common::Target& target, + const bool just_reorder = false) { + // reorder none-last reduce axis to last. + // like: shape = [16,16,16,16,16],axes = [1,3] -> new order = [0, 2, 4, 1, 3]. + std::vector order; + int n_out_dims = ir_sch.GetLoops(block_name).size(); + for (int idx = 0; idx < n_out_dims; ++idx) { + if (std::find(axes.begin(), axes.end(), idx) == axes.end()) { + order.push_back(idx); + } + } + for (auto axis : axes) { + order.push_back(axis); + } + ir_sch.Reorder(ir_sch.GetBlock(block_name), order); + + if (just_reorder) { + return; + } + // fuse others none-reduce axis. + int last_dimension_num = n_out_dims - axes.back() - 1; + int index = n_out_dims - last_dimension_num - axes.size(); + + // fuse last_dimension_num - 1 times + for (auto idx = index; idx < index + last_dimension_num - 1; ++idx) { + ir_sch.Fuse(block_name, {index, index + 1}); + } + + auto loops = ir_sch.GetLoops(block_name); + auto psize = ir::GetLoopExtent(loops[index]); + + if (psize > target.max_num_threads()) { + for (int idx = target.max_num_threads(); idx > 0; --idx) { + if (psize % idx == 0) { + ir_sch.Split(loops[index], {-1, idx}); + break; + } + CHECK_GT(idx, 1); + } + } + + // fuse index - 1 times + for (int idx = 0; idx < index - 1; ++idx) { + ir_sch.Fuse(block_name, {0, 1}); + } +} + +void LoopAssignReduceWithoutLast(ir::IRSchedule& ir_sch, + const std::string& block_name, + const std::vector& inshape, + const std::vector& axes, + const common::Target& target) { + CHECK(axes.size()); + int lane = 1; + int max_num_threads = target.max_num_threads(); + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + lane *= inshape[idx]; + } + CHECK_LE(lane, max_num_threads / 2) << "Parallel threads must less equal max_num_threads/2 on gpu!"; + int pos = 0; + int index = axes.size() - 1; + for (; index >= 0; --index) { + if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { + pos = axes[index + 1]; + break; + } + + lane *= inshape[axes[index]]; + if (lane > max_num_threads / 2) { + pos = axes[index]; + break; + } + + if (index == 0) { + pos = axes[0]; + } + } + + if (lane > max_num_threads / 2) { + int prefix = inshape[axes[index]]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { + if (prefix % idx == 0) { + ir_sch.Split(block_name, axes[index], {-1, idx}); + break; + } + CHECK_GT(idx - 1, (max_num_threads / 2) / tail) << "idx should greater than (max_num_threads / 2) / tail."; + } + } + + // insert 1 + for (int idx = 0; idx < axes.size() - 1 - index; ++idx) { + auto loops = ir_sch.GetLoops(block_name); + ir_sch.Split(block_name, pos, {-1, ir::GetLoopExtent(loops[pos])}); + } + LoopOrderAssignReduce(ir_sch, block_name, axes, target); + // return insert 1 + int start_index = ir_sch.GetLoops(block_name).size() - axes.size(); + for (int idx = 0; idx < axes.size(); ++idx) { + auto loops = ir_sch.GetLoops(block_name); + if (ir::GetLoopExtent(loops[start_index]) == 1) { + ir_sch.Fuse({loops[start_index - 1], loops[start_index]}); + } else { + ++start_index; + } + } +} + +void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, + const std::string& block_name, + const std::vector& inshape, + const std::vector& axes, + const common::Target& target) { + // find first reduce and second reduce axis. + int lane = 1; + int index = static_cast(axes.size()) - 1; + auto max_num_threads = target.max_num_threads(); + for (; index >= 0; --index) { + if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { + break; + } + lane *= inshape[axes[index]]; + if (index == 0 && lane <= max_num_threads) { + LOG(FATAL) << "Error! lane is less equal than max_num_threads, Please check!"; + } + if (lane >= max_num_threads / 2) { + if (lane <= max_num_threads) { + --index; + } + break; + } + } + std::vector first_axes(axes.begin(), axes.begin() + index + 1); + if (lane > max_num_threads) { + // last reduce axis size > 1024 + if (index == static_cast(axes.size()) - 1) { + int idx = max_num_threads; + do { + if (lane % idx == 0) { + ir_sch.Split(block_name, axes[index], {-1, idx}); + break; + } + --idx; + } while (idx >= max_num_threads / 2); + // if can't be divide by(1024, 512), it's shouldn't be fused. + CHECK_GE(idx, max_num_threads / 2) << "Check bounds exist, can't fuse!"; + } else { + int axis = axes[index]; + int prefix = inshape[axis]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { + if (prefix % idx == 0) { + ir_sch.Split(block_name, axis, {-1, idx}); + break; + } + CHECK_GT(idx, (max_num_threads / 2) / tail) << "Error, it's shouldn't fuse!"; + } + } + LoopOrderAssignReduce(ir_sch, block_name, first_axes, target); + } else { + int fuse_times = axes.size() - (index + 1) - 1; + for (int idx = 0; idx < fuse_times; ++idx) { + ir_sch.Fuse(block_name, {axes[index + 1], axes[index + 1] + 1}); + } + LoopOrderAssignReduce(ir_sch, block_name, first_axes, target, true); + // fuse axis before reduce to bind blockidx. + for (int idx = 0; idx < int(inshape.size() - axes.size()) - 1; ++idx) { + ir_sch.Fuse(block_name, {0, 1}); + } + } +} + +bool CanbeInline(Node* node, + const std::vector consumers, + const Node* reducer, + const Node* laster, + const GroupPtr& group, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict) { + if (group->output_nodes.count(node)) { + return false; + } + if (IsConstOp(node)) { + return true; + } + + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + for (auto consumer : consumers) { + if (op_pattern_dict[consumer->op()] == framework::kReduction) { + return false; + } + } + + if (op_pattern_dict[node->op()] == framework::kReduction) { + return false; + } + + if (consumers.size() == 1) { + return true; + } + + if (reducer) { + // node is before reducer and node is not after reduce. + if (FindReducerInRoute(node, nodes_set, GetConsumersInSet) && + !FindReducerInRoute(node, nodes_set, GetProducersInSet)) { + auto node_shape = GetOutputShape(node, shape_dict); + auto input_shape = GetInputShape(reducer, shape_dict); + // check with same shape with reducer input. + if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) != + std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies())) { + return true; + } + } + + return false; + } else { + auto node_shape = GetOutputShape(node, shape_dict); + auto last_shape = GetOutputShape(laster, shape_dict); + if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) != + std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies())) { + return true; + } + + return false; + } +} + +Node* GetMasterToComputeAt(Node* node, + const std::vector& nodes_in_order, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + // if node is reduction, try find horizontal to compute at. + if (op_pattern_dict[node->op()] == framework::kReduction) { + // find all reduce node has done schedule. + std::unordered_set done_schedule; + for (auto tmp : nodes_in_order) { + if (tmp == node) { + break; + } + if (op_pattern_dict[tmp->op()] == framework::kReduction) { + done_schedule.insert(tmp); + } + } + // remove all consuemr reducer node of node from done_schedule. + std::unordered_set visited; + std::queue candidates; + candidates.push(node); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto consumer : GetConsumersInSet(candidate, nodes_set)) { + // remove reduction node from done_schedule. + if (op_pattern_dict[consumer->op()] == framework::kReduction) { + done_schedule.erase(consumer); + } + if (visited.count(consumer)) { + continue; + } + candidates.push(consumer); + visited.insert(consumer); + } + } + + if (done_schedule.size()) { + auto shape = shape_dict.at(node->inlinks_in_order()[0]->source()->id()); + for (auto rnode : done_schedule) { + auto rshape = shape_dict.at(rnode->inlinks_in_order()[0]->source()->id()); + if (shape == rshape) { + return rnode; + } + } + return *done_schedule.begin(); + } + } + + // collect all consumers. + std::unordered_set visited, masters; + std::queue candidates; + candidates.push(node); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + auto consumers = FindConsumers(candidate, nodes_set, virtual_consumers); + for (auto consumer : consumers) { + if (visited.count(consumer)) { + continue; + } + if (nodes_inline.count(consumer)) { + candidates.push(consumer); + visited.insert(consumer); + } else { + masters.insert(consumer); + } + } + } + + // nodes-in-order + for (int idx = 0; idx < nodes_in_order.size(); ++idx) { + if (nodes_in_order[idx] == node) { + for (int idy = idx - 1; idy >= 0; --idy) { + if (masters.count(nodes_in_order[idy])) { + return nodes_in_order[idy]; + } + } + break; + } + } + return nullptr; +} + +void LoopAssignReduce(ir::IRSchedule& ir_sch, + const Node* node, + const Node* reducer, + const Target& target, + const std::unordered_map& tensor_map, + const absl::flat_hash_map& shape_dict) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + // if node is reducer, return. + if (op_pattern_dict[node->op()] == framework::kReduction) { + return; + } + auto node_data = GetNodeData(node); + auto reducer_data = GetNodeData(reducer); + + // flatten loops. + auto loops = ir_sch.GetLoops(node_data->id()); + // do loop flatten. + if (op_pattern_dict[node->op()] == framework::kElementWise) { + ir_sch.FlattenLoops(loops, true); + } else { + ir_sch.FlattenLoops(loops, false); + } + + // shape and axis. + CHECK(shape_dict.count(reducer->inlinks_in_order()[0]->source()->id())); + auto shape = shape_dict.at(reducer->inlinks_in_order()[0]->source()->id()); + auto axes = absl::get>(reducer->attrs.attr_store.at("dim")); + if (axes.empty()) { + for (int idx = 0; idx < shape.size(); idx++) { + axes.push_back(idx); + } + } + + auto copy_loop_info = [](std::vector& loops, std::vector& rloops) { + for (int idx = 0; idx < std::min(rloops.size(), loops.size()); ++idx) { + auto l0 = rloops[idx].As(); + auto l1 = loops[idx].As(); + l1->set_for_type(l0->for_type()); + l1->set_bind_info(l0->bind_info()); + } + }; + + auto node_shape = shape_dict.at(node_data->id()); + // node output is same shape with reduce output. + if (std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) != + std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies())) { + // split loop to assign master loop + int extend = 1; + std::vector factors; + loops = ir_sch.GetLoops(node_data->id()); + auto rloops = ir_sch.GetLoops(reducer_data->id()); + + for (auto& loop : rloops) { + extend *= loop.As()->extent.as_int32(); + if (extend > loops.back().As()->extent.as_int32()) { + break; + } + CHECK_LE(extend, loops.back().As()->extent.as_int32()); + factors.push_back(loop.As()->extent.as_int32()); + } + + ir_sch.Split(loops.back(), factors); + loops = ir_sch.GetLoops(node_data->id()); + // copy loop info form rloops. + copy_loop_info(loops, rloops); + return; + } + + // node output is same shape with reduce input. + if (WithoutLastDimInReduce(shape, axes)) { + auto nloops = ir_sch.GetLoops(node_data->id()); + ir_sch.Split(nloops.back(), shape); + // if using block shuffle + if (tensor_map.count(reducer_data->id() + "_1")) { + LoopAssignReduceWithoutLast(ir_sch, node_data->id(), shape, axes, target); + auto nloops = ir_sch.GetLoops(node_data->id()); + auto rloops = ir_sch.GetLoops(tensor_map.find(reducer_data->id() + "_0")->second->name); + if (nloops.size() < rloops.size()) { + ir_sch.Split(nloops[0], {1, -1}); + } + + nloops = ir_sch.GetLoops(node_data->id()); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } else { + LoopOrderAssignReduce(ir_sch, node_data->id(), axes, target); + auto nloops = ir_sch.GetLoops(node_data->id()); + auto rloops = ir_sch.GetLoops(tensor_map.find(reducer_data->id())->second->name); + if (nloops.size() < rloops.size()) { + ir_sch.Split(nloops[0], {1, -1}); + } + + nloops = ir_sch.GetLoops(node_data->id()); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } + } else { + if (tensor_map.count(reducer_data->id() + "_1")) { + { + auto nloops = ir_sch.GetLoops(node_data->id()); + ir_sch.Split(nloops.back(), shape); + } + LoopAssignReduceWithLast(ir_sch, node_data->id(), shape, axes, target); + + auto nloops = ir_sch.GetLoops(node_data->id()); + auto rloops = ir_sch.GetLoops(tensor_map.find(reducer_data->id() + "_1")->second->name); + if (nloops.size() < rloops.size()) { + ir_sch.Split(nloops[0], {1, -1}); + } + + nloops = ir_sch.GetLoops(node_data->id()); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } else if (tensor_map.count(reducer_data->id() + "_0")) { + auto tensor = tensor_map.find(reducer_data->id() + "_0")->second; + auto rloops = ir_sch.GetLoops(tensor->name); + std::vector factors; + for (auto& loop : rloops) { + factors.push_back(loop.As()->extent.as_int32()); + } + auto nloops = ir_sch.GetLoops(node_data->id()); + ir_sch.Split(nloops.back(), factors); + + nloops = ir_sch.GetLoops(node_data->id()); + // copy loop info form rloops. + copy_loop_info(nloops, rloops); + } else { + LOG(FATAL) << "Error! Unkown Reduce Type!"; + } + } +} + +// The struct used to remove the original block in ComputeAt. +class RemoveExpr : public ir::IRMutator<> { + public: + RemoveExpr(const Expr& target) : target_(target) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { IRMutator::Visit(expr, op); } + + void Visit(const ir::For* expr, Expr* op) override { IRMutator::Visit(expr, op); } + + void Visit(const ir::Block* expr, Expr* op) override { + auto* node = op->As(); + auto iter = std::find(node->stmts.begin(), node->stmts.end(), target_); + if (iter != node->stmts.end()) { + node->stmts.erase(iter); + } else { + for (auto stmt : node->stmts) { + IRMutator::Visit(&stmt, &stmt); + } + } + } + + private: + const Expr& target_; +}; + +void MergeLoops(ir::Expr root, std::vector& src, std::vector& dst, int index) { + if (index < 0) { + return; + } + CHECK_GT(src.size(), index) << "\nindex -> " << index << "\n" << src[0]; + CHECK_GT(dst.size(), index) << "\nindex -> " << index << "\n" << dst[0]; + + if (src[0] == dst[0]) { + return; + } + + std::vector src_vars; + std::vector dst_vars; + for (int idx = 0; idx <= index; ++idx) { + src_vars.push_back(src[idx].As()->loop_var); + dst_vars.push_back(ir::Expr(dst[idx].As()->loop_var)); + } + + auto src_body = src[index].As()->body; + ReplaceExpr(&src_body, src_vars, dst_vars); + dst[index].As()->body = ir::Block::Make({src_body, dst[index].As()->body}); + + RemoveExpr remove_expr(src[0]); + remove_expr(&root); +} + +void InsertSyncThread(ir::IRSchedule& ir_sch, + const Node* node, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + CHECK(shape_dict.count(node->inlinks_in_order()[0]->source()->id())); + auto shape = shape_dict.at(node->inlinks_in_order()[0]->source()->id()); + auto axes = absl::get>(node->attrs.attr_store.at("dim")); + if (axes.empty()) { + for (int idx = 0; idx < shape.size(); idx++) { + axes.push_back(idx); + } + } + if (!WithoutLastDimInReduce(shape, axes)) { + return; + } + + auto node_data = GetNodeData(node); + std::string post = ""; + for (int idx = 0;; ++idx) { + if (!tensor_map.count(node_data->id() + post)) { + break; + } + auto tensor = tensor_map.find(node_data->id() + post)->second; + if (!ir_sch.HasBlock(tensor->name)) { + break; + } + + post = "_" + std::to_string(idx); + if (idx > 0) { + // insert syncthreads. + auto loops = ir_sch.GetLoops(node_data->id()); + ir_sch.SyncThreads(loops.back(), false); + return; + } + } +} + +// The struct used to remove the original block in ComputeAt. +class InsertExpr : public ir::IRMutator<> { + public: + InsertExpr(Expr& target, Expr& anchor) : target_(target), anchor_(anchor) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { IRMutator::Visit(expr, op); } + + void Visit(const ir::For* expr, Expr* op) override { IRMutator::Visit(expr, op); } + + void Visit(const ir::Block* expr, Expr* op) override { + auto* node = op->As(); + auto iter = std::find(node->stmts.begin(), node->stmts.end(), anchor_); + if (iter != node->stmts.end()) { + node->stmts.insert(iter, target_); + } else { + for (auto stmt : node->stmts) { + IRMutator::Visit(&stmt, &stmt); + } + } + } + + private: + Expr target_; + Expr anchor_; +}; + +void MergeReduceToReduce(ir::IRSchedule& ir_sch, + const Node* node, + const Node* master, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto node_data = GetNodeData(node); + auto master_data = GetNodeData(master); + + CHECK(shape_dict.count(node->inlinks_in_order()[0]->source()->id())); + auto shape = shape_dict.at(node->inlinks_in_order()[0]->source()->id()); + auto axes = absl::get>(node->attrs.attr_store.at("dim")); + if (axes.empty()) { + for (int idx = 0; idx < shape.size(); idx++) { + axes.push_back(idx); + } + } + if (WithoutLastDimInReduce(shape, axes)) { + auto mshape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); + // using block shuffle + if (tensor_map.count(node_data->id() + "_1")) { + if (shape == mshape) { + // block shuffle + { + auto block = ir_sch.GetBlock(node_data->id()); + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SimpleComputeAt(block, loops.back()); + } + // reduce loop + { + auto n_tensor = tensor_map.find(node_data->id() + "_0")->second; + auto m_tensor = tensor_map.find(master_data->id() + "_0")->second; + + auto block = ir_sch.GetBlock(n_tensor->name); + auto loops = ir_sch.GetLoops(m_tensor->name); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(n_tensor->name + "__reduce_init"); + auto loops = ir_sch.GetLoops(m_tensor->name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + } else { + auto n_tensor = tensor_map.find(node_data->id() + "_0")->second; + auto m_tensor = tensor_map.find(master_data->id() + "_0")->second; + if (n_tensor->shape.back() == m_tensor->shape.back()) { + // block shuffle + { + auto block = ir_sch.GetBlock(node_data->id()); + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SimpleComputeAt(block, loops.back()); + } + // reduce loop + { + auto n_tensor = tensor_map.find(node_data->id() + "_0")->second; + auto m_tensor = tensor_map.find(master_data->id() + "_0")->second; + + auto n_block = ir_sch.GetBlock(n_tensor->name); + auto m_block = ir_sch.GetBlock(m_tensor->name); + + auto n_loops = ir_sch.GetLoops(n_tensor->name); + auto m_loops = ir_sch.GetLoops(m_tensor->name); + + std::vector src_vars; + std::vector dst_vars; + for (int idx = 0; idx < m_loops.size() - 1; ++idx) { + src_vars.push_back(n_loops[idx].As()->loop_var); + dst_vars.push_back(ir::Expr(m_loops[idx].As()->loop_var)); + } + ReplaceExpr(&n_block, src_vars, dst_vars); + + int index = n_loops.size(); + InsertExpr insert_expr(n_loops[index - 1], m_loops[index - 1]); + insert_expr(&m_loops[0]); + + // reduce init + { + auto block = ir_sch.GetBlock(n_tensor->name + "__reduce_init"); + auto loops = ir_sch.GetLoops(m_tensor->name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + RemoveExpr remove_expr(n_loops[0]); + remove_expr(&ir_sch.GetModule().GetExprs().at(0)); + } + } else { + // block shuffle + { + auto block = ir_sch.GetBlock(node_data->id()); + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SimpleComputeAt(block, loops.back()); + } + // reducer loop + { + auto n_tensor = tensor_map.find(node_data->id() + "_0")->second; + auto m_tensor = tensor_map.find(master_data->id() + "_0")->second; + + auto n_loops = ir_sch.GetLoops(n_tensor->name); + auto m_loops = ir_sch.GetLoops(m_tensor->name); + + MergeLoops(ir_sch.GetModule().GetExprs().at(0), n_loops, m_loops, 0); + } + } + } + } else { + if (shape == mshape) { + // reduce loop + { + auto block = ir_sch.GetBlock(node_data->id()); + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(node_data->id() + "__reduce_init"); + auto loops = ir_sch.GetLoops(master_data->id() + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + } else { + // reduce loop + { + auto block = ir_sch.GetBlock(node_data->id()); + auto nloops = ir_sch.GetLoops(node_data->id()); + auto mloops = ir_sch.GetLoops(master_data->id()); + for (int idx = 0; idx < mloops.size(); ++idx) { + if (GetLoopExtent(nloops[idx]) != GetLoopExtent(mloops[idx])) { + ir_sch.SimpleComputeAt(block, mloops[idx - 1]); + break; + } + } + // reduce init + { + auto block = ir_sch.GetBlock(node_data->id() + "__reduce_init"); + auto loops = ir_sch.GetLoops(master_data->id() + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + } + } + } else { + if (tensor_map.count(node_data->id() + "_1")) { + // identity + { + auto block = ir_sch.GetBlock(node_data->id()); + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SimpleComputeAt(block, loops.back()); + } + // reduce + { + auto n_tensor = tensor_map.find(node_data->id() + "_1")->second; + auto m_tensor = tensor_map.find(master_data->id() + "_1")->second; + + auto block = ir_sch.GetBlock(n_tensor->name); + auto loops = ir_sch.GetLoops(m_tensor->name); + ir_sch.SimpleComputeAt(block, loops.back()); + // reduce init + { + auto block = ir_sch.GetBlock(n_tensor->name + "__reduce_init"); + auto loops = ir_sch.GetLoops(m_tensor->name + "__reduce_init"); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } + // block shuffle + { + auto n_tensor = tensor_map.find(node_data->id() + "_0")->second; + auto m_tensor = tensor_map.find(master_data->id() + "_0")->second; + + auto n_block = ir_sch.GetBlock(n_tensor->name); + auto m_block = ir_sch.GetBlock(m_tensor->name); + + auto n_loops = ir_sch.GetLoops(n_tensor->name); + auto m_loops = ir_sch.GetLoops(m_tensor->name); + + std::vector src_vars; + std::vector dst_vars; + for (int idx = 0; idx < m_loops.size(); ++idx) { + src_vars.push_back(n_loops[idx].As()->loop_var); + dst_vars.push_back(ir::Expr(m_loops[idx].As()->loop_var)); + } + ReplaceExpr(&n_block, src_vars, dst_vars); + + InsertExpr insert_expr(n_block, m_block); + insert_expr(&m_loops.back()); + + RemoveExpr remove_expr(n_loops[0]); + remove_expr(&ir_sch.GetModule().GetExprs().at(0)); + } + } else if (tensor_map.count(node_data->id() + "_0")) { + // identity + { + auto block = ir_sch.GetBlock(node_data->id()); + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SimpleComputeAt(block, loops.back()); + } + // shuffle reduce + { + auto n_tensor = tensor_map.find(node_data->id() + "_0")->second; + auto m_tensor = tensor_map.find(master_data->id() + "_0")->second; + + auto block = ir_sch.GetBlock(n_tensor->name); + auto loops = ir_sch.GetLoops(m_tensor->name); + ir_sch.SimpleComputeAt(block, loops.back()); + } + } else { + LOG(FATAL) << "Error! Unkown Reduce Type, Please Check!"; + } + } +} + +void MergeReduceLoop(ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + if (op_pattern_dict[master->op()] == kReduction && node != master) { + MergeReduceToReduce(ir_sch, node, master, shape_dict, tensor_map); + return; + } + + auto node_data = GetNodeData(node); + auto master_data = GetNodeData(master); + + int min_index_loop = INT_MAX; + std::string post_ = "", post__ = "_0"; + for (int idx = 0;; ++idx) { + if (!tensor_map.count(node_data->id() + post__)) { + break; + } + auto tensor_ = tensor_map.find(node_data->id() + post_)->second; + auto tensor__ = tensor_map.find(node_data->id() + post__)->second; + if (!ir_sch.HasBlock(tensor__->name)) { + break; + } + + auto dst_loops = ir_sch.GetLoops(tensor_->name); + auto src_loops = ir_sch.GetLoops(tensor__->name); + int index = -1; + while (src_loops[index + 1].As()->extent.as_int32() == + dst_loops[index + 1].As()->extent.as_int32()) { + ++index; + if (src_loops.size() == index + 1 || dst_loops.size() == index + 1) { + break; + } + } + min_index_loop = std::min(min_index_loop, index); + MergeLoops(ir_sch.GetModule().GetExprs().at(0), src_loops, dst_loops, index); + + post_ = "_" + std::to_string(idx); + post__ = "_" + std::to_string(idx + 1); + } + InsertSyncThread(ir_sch, node, shape_dict, tensor_map); + + if (node == master) return; + auto node_loops = ir_sch.GetLoops(node_data->id()); + auto master_loops = ir_sch.GetLoops(master_data->id()); + + int index = std::min(node_loops.size(), master_loops.size()) - 1; + do { + // if loop range is not equal. + if (node_loops[index].As()->extent.as_int32() != master_loops[index].As()->extent.as_int32()) { + continue; + } + + MergeLoops(ir_sch.GetModule().GetExprs().at(0), node_loops, master_loops, std::min(index, min_index_loop)); + if (index > min_index_loop) { + auto block = ir_sch.GetBlock(node_data->id()); + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SimpleComputeAt(block, loops.back()); + } + + break; + } while (--index >= 0); +} + +void LoopComputeAt(ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + if (!group->output_nodes.count(node)) { + auto block = ir_sch.GetBlock(GetNodeData(node)->id()); + ir_sch.SetBuffer(block, "local", true); + } + + if (op_pattern_dict[node->op()] == framework::kReduction) { + MergeReduceLoop(ir_sch, node, master, shape_dict, tensor_map); + return; + } + + if (node == master) return; + + auto node_data = GetNodeData(node); + auto master_data = GetNodeData(master); + + auto node_loops = ir_sch.GetLoops(node_data->id()); + auto master_loops = ir_sch.GetLoops(master_data->id()); + + if (op_pattern_dict[master->op()] == framework::kReduction) { + // find real master loops. + std::string prefix = "", post = ""; + for (int idx = 0;; ++idx) { + if (!tensor_map.count(master_data->id() + post)) { + break; + } + auto tensor = tensor_map.find(master_data->id() + post)->second; + if (!ir_sch.HasBlock(tensor->name)) { + break; + } + + prefix = post; + post = "_" + std::to_string(idx); + } + + auto tensor = tensor_map.find(master_data->id() + prefix)->second; + master_loops = ir_sch.GetLoops(tensor->name); + } + + int index = std::min(node_loops.size(), master_loops.size()) - 1; + do { + // if loop range is not equal. + if (node_loops[index].As()->extent.as_int32() != master_loops[index].As()->extent.as_int32()) { + continue; + } + + MergeLoops(ir_sch.GetModule().GetExprs().at(0), node_loops, master_loops, index); + break; + } while (--index >= 0); +} + +std::unordered_map GetNodeDataSet(const std::unordered_set& nodes_set) { + std::unordered_map node_data_set; + for (auto node : nodes_set) { + auto node_data = GetNodeData(node); + node_data_set[node_data->id()] = node_data; + } + return node_data_set; +} + +Node* GetMaster(Node* node, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set) { + // find consumer + std::unordered_set visited; + std::queue candidates; + candidates.push(node); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + auto consumers = GetConsumersInSet(candidate, nodes_set); + for (auto consumer : consumers) { + if (visited.count(consumer)) { + continue; + } + if (nodes_inline.count(consumer)) { + candidates.push(consumer); + visited.insert(consumer); + } else { + return consumer; + } + } + } + + return nullptr; +} + +void SyncThreadWithShared(ir::IRSchedule& ir_sch, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto exprs_inorder = ir_sch.GetAllBlocks(); + auto node_data_set = GetNodeDataSet(nodes_set); + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + + std::unordered_set sync_mark; + auto check_sync_mark = [&](const int start, const std::string& m_id) { + for (int idx = start + 1; exprs_inorder.size(); ++idx) { + auto expr = exprs_inorder[idx]; + CHECK(expr.As()); + CHECK(expr.As()->schedule_block.As()); + auto block = expr.As()->schedule_block.As(); + + if (sync_mark.count(block->name)) { + return false; + } + + if (block->name == m_id) { + return true; + } + } + return false; + }; + + for (int idx = 0; idx < exprs_inorder.size() - 1; ++idx) { + auto expr = exprs_inorder[idx]; + CHECK(expr.As()); + CHECK(expr.As()->schedule_block.As()); + auto block = expr.As()->schedule_block.As(); + + if (!node_data_set.count(block->name)) { + continue; + } + auto node_data = node_data_set.find(block->name)->second; + auto node = node_data->source_node.get(); + auto node_shape = shape_dict.at(node_data->id()); + + auto master = GetMaster(node, nodes_inline, nodes_set); + if (!master) { + continue; + } + + auto master_data = GetNodeData(master); + auto master_shape = shape_dict.at(master_data->id()); + if (op_pattern_dict[master->op()] == framework::kReduction) { + master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); + } + + auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + + if (node_size == master_size) { + continue; + } + + { + auto block = ir_sch.GetBlock(node_data->id()); + ir_sch.SetBuffer(block, "shared", true); + } + + if (check_sync_mark(idx, master_data->id())) { + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SyncThreads(loops.back(), false); + sync_mark.insert(master_data->id()); + } + } +} + +void DoReduceInline(ir::IRSchedule& ir_sch, + Node* node, + bool caninline, + std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto master = GetMaster(node, nodes_inline, nodes_set); + if (!master) { + return; + } + + auto node_data = GetNodeData(node); + auto master_data = GetNodeData(master); + if (caninline && tensor_map.count(node_data->id() + "_0")) { + auto block = ir_sch.GetBlock(node_data->id()); + ir_sch.ComputeInline(block); + nodes_inline.insert(node); + } +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h new file mode 100644 index 0000000000..0854723b85 --- /dev/null +++ b/cinn/hlir/framework/op_lowering_util.h @@ -0,0 +1,97 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "cinn/hlir/framework/op_lowering.h" + +namespace cinn { +namespace hlir { +namespace framework { + +std::vector GetInputNodeData(const Node* node); + +ir::Tensor GetTensor(const NodeData* node_data, + const absl::flat_hash_map& type_dict, + const absl::flat_hash_map& shape_dict); + +std::vector CollectInputTensor(const Node* node, + std::vector& func_args, + std::unordered_map& tensor_map, + const absl::flat_hash_map& type_dict, + const absl::flat_hash_map& shape_dict); + +std::unordered_map BuildVirtualConsumer(const GroupPtr& group, + const absl::flat_hash_map& shape_dict); + +NodeData* GetNodeData(const Node* node); + +std::vector GetAllNodeData(const Node* node); + +std::vector GetConsumersInSet(const Node* node, const std::unordered_set& node_set); + +std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_map& virtual_consumers); + +Node* FindGlobalReducer(const std::vector& nodes_in_order); + +Node* FindNearestReducer(const Node* node, const std::unordered_set& nodes_set); + +bool CanbeInline(Node* node, + const std::vector consumers, + const Node* reducer, + const Node* laster, + const GroupPtr& group, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict); + +Node* GetMasterToComputeAt(Node* node, + const std::vector& nodes_in_order, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict); + +void LoopAssignReduce(ir::IRSchedule& ir_sch, + const Node* node, + const Node* reducer, + const Target& target, + const std::unordered_map& tensor_map, + const absl::flat_hash_map& shape_dict); + +void LoopComputeAt(ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map); + +void SyncThreadWithShared(ir::IRSchedule& ir_sch, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map); + +void DoReduceInline(ir::IRSchedule& ir_sch, + Node* node, + bool caninline, + std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map); + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/reduction_test.cc b/cinn/hlir/op/reduction_test.cc index 5a05169a17..22fdfb302c 100755 --- a/cinn/hlir/op/reduction_test.cc +++ b/cinn/hlir/op/reduction_test.cc @@ -144,6 +144,13 @@ std::pair GenReduceCode(const std::vector& shape, return std::pair(host_module, source_code); } +TEST(Operator, Operator_Reduction_Case_Last_Dim_1) { + std::vector shape = {10, 100, 1}; + std::vector dim = {0, 2}; + + GenReduceCode(shape, dim, "reduce_cast_with_last_1"); +} + TEST(Operator, Operator_Reduction_Case_0) { std::vector shape = {16, 16, 8, 16}; std::vector dim = {2, 3}; diff --git a/cinn/hlir/pass/const_propagate_test.cc b/cinn/hlir/pass/const_propagate_test.cc index c352eee4ca..3d73e1356d 100644 --- a/cinn/hlir/pass/const_propagate_test.cc +++ b/cinn/hlir/pass/const_propagate_test.cc @@ -102,10 +102,9 @@ TEST(const_bn, const_bn) { hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); - auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); // Revert changes in PR #990 to pass the model unittests - ASSERT_EQ(run_instrs.size(), 2); + ASSERT_EQ(run_instrs.size(), 1); scope->Var("A"); scope->Var("Scale"); diff --git a/cinn/hlir/pass/fusion_helper_base.h b/cinn/hlir/pass/fusion_helper_base.h index d17555cb65..d4a0fc56ce 100644 --- a/cinn/hlir/pass/fusion_helper_base.h +++ b/cinn/hlir/pass/fusion_helper_base.h @@ -80,6 +80,13 @@ class FusionHelperBase { return shape_dict_.at(node_data->id()); } + shape_t GetNodeInputShape(const Node* node) const { + auto node_datas = GetProducerNodeData(node); + CHECK_GT(node_datas.size(), 0); + CHECK(shape_dict_.count(node_datas[0]->id())) << "Can't find " << node_datas[0]->id() << " 's shape!"; + return shape_dict_.at(node_datas[0]->id()); + } + static std::vector GetProducerNodeData(const Node* node) { std::vector producer_node_data; for (auto& edge : node->inlinks_in_order(true)) { diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 1c7c4e1355..37424855c1 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -616,14 +616,22 @@ class FusionMergePassHelper : public FusionHelperBase { candidates.insert(consumer); } else { VLOG(4) << "Fuse Producer : " << producer->group_id << " into Consumer : " << consumer->group_id; + auto& sub_group = consumer->fused_sub_groups.front(); + auto constant_node = producer->CollectNodes()[0]; + if (sub_group->NodeSet().count(constant_node)) { + // remove depency. + consumer->input_nodes.erase(constant_node); + consumer->producer_groups.erase(producer); + producer->consumer_groups.erase(consumer); + continue; + } consumer->group_id = producer->group_id + "_" + consumer->group_id; // just merge the node into group. - auto& sub_group = consumer->fused_sub_groups.front(); sub_group->group_id = producer->group_id + "_" + sub_group->group_id; - sub_group->nodes.insert(sub_group->nodes.begin(), producer->CollectNodes()[0]); - sub_group->nodes_set.insert(producer->CollectNodes()[0]); + sub_group->nodes.insert(sub_group->nodes.begin(), constant_node); + sub_group->nodes_set.insert(constant_node); // remove depency. - consumer->input_nodes.erase(producer->CollectNodes()[0]); + consumer->input_nodes.erase(constant_node); consumer->producer_groups.erase(producer); producer->consumer_groups.erase(consumer); } diff --git a/cinn/hlir/pass/fusion_merge_pass_test.cc b/cinn/hlir/pass/fusion_merge_pass_test.cc index eabd712c3f..eb5e0dac8d 100755 --- a/cinn/hlir/pass/fusion_merge_pass_test.cc +++ b/cinn/hlir/pass/fusion_merge_pass_test.cc @@ -199,7 +199,7 @@ TEST(FusionMergePass, Broadcast_Test_0) { auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 2); + CHECK_EQ(graph->fusion_groups.size(), 1); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); CHECK_EQ(graph->fusion_groups.size(), 1); } diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index ba42028d57..23f9c5ed8d 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -223,12 +223,7 @@ class OpFusionPassHelper : public FusionHelperBase { if (is_same_size(helper, producer, consumer)) { return true; } - - if (helper->IsConstOp(producer) && !helper->output_nodes_set_.count(producer)) { - return true; - } - - return false; + return !helper->output_nodes_set_.count(producer); }}, // horizontal or vertical relation, check with same output shape with horizontal relation or with last // successive dimension less than 1024 for gpu. @@ -249,7 +244,13 @@ class OpFusionPassHelper : public FusionHelperBase { // horizontal or vertical relation(Broadcast + *Elementwise*), check with same output shape. {framework::kElementWise, is_same_size}, // must be horizontal, as Broadcast + Broadcast is not allowed. - {framework::kBroadcast, is_same_size}, + {framework::kBroadcast, + [](const FusionHelperBase* helper, const Node* producer, const GroupPtr& consumer) -> bool { + if (is_same_size(helper, producer, consumer)) { + return true; + } + return !helper->output_nodes_set_.count(producer); + }}, // horizontal or vertical relation(Broadcast + Reduce). {framework::kReduction, horizontal_or_vertical_reduce_relation}, // can be horizontal or can compute inline, check with same output shape or just one consumer. @@ -262,17 +263,13 @@ class OpFusionPassHelper : public FusionHelperBase { { FusionRelation relation; // producer -> consumer - relation.op_kind = {framework::kElementWise}; + relation.op_kind = {framework::kElementWise, framework::kBroadcast}; // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce. {framework::kElementWise, without_last_dimension_in_reduce}, // must be horizontal relation, check with same output shape and without last dimension in reduce. - {framework::kBroadcast, - [](const FusionHelperBase* helper, const Node* producer, const GroupPtr& consumer) -> bool { - return is_same_size(helper, producer, consumer) && - without_last_dimension_in_reduce(helper, producer, consumer); - }}, + {framework::kBroadcast, reduce_fuse_broadcast}, // must be horizontal relation and with same reduce attr. {framework::kReduction, reduce_fuse_reduce}, // no_fuse diff --git a/cinn/hlir/pass/op_fusion_pass_test.cc b/cinn/hlir/pass/op_fusion_pass_test.cc index 28c25a30ab..4ba77dec9c 100755 --- a/cinn/hlir/pass/op_fusion_pass_test.cc +++ b/cinn/hlir/pass/op_fusion_pass_test.cc @@ -111,7 +111,7 @@ TEST(OpFusionPass, Brodcast_Test_1) { auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 2); + CHECK_EQ(graph->fusion_groups.size(), 1); } TEST(OpFusionPass, Brodcast_Test_2) { @@ -131,7 +131,7 @@ TEST(OpFusionPass, Brodcast_Test_2) { auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 2); + CHECK_EQ(graph->fusion_groups.size(), 1); } TEST(OpFusionPass, Reduce_Test_0) { diff --git a/cinn/hlir/pass/op_fusion_pass_util.h b/cinn/hlir/pass/op_fusion_pass_util.h index e231296513..79ddcd2dda 100644 --- a/cinn/hlir/pass/op_fusion_pass_util.h +++ b/cinn/hlir/pass/op_fusion_pass_util.h @@ -251,6 +251,105 @@ CONDITION_FUNC(horizontal_with_same_size) { return false; } +CONDITION_FUNC(reduce_fuse_broadcast) { + if (horizontal_with_same_size(helper, producer, consumer)) { + return true; + } + + if (is_horizontal_relation(helper, producer, consumer)) { + return false; + } + + if (helper->target_ != common::DefaultNVGPUTarget()) { + return true; + } + + auto rinput_shape = helper->GetNodeInputShape(producer); + auto reduce_axes = absl::get>(producer->attrs.attr_store.at("dim")); + auto keep_dim = absl::get(producer->attrs.attr_store.at("keep_dim")); + for (auto& axis : reduce_axes) { + if (axis == -1) { + axis = rinput_shape.size() - 1; + } + } + + int reduce_size = rinput_shape.back(); + for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) { + if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) { + return false; + } + reduce_size *= rinput_shape[idx - 1]; + } + + if (reduce_size > helper->target_.max_num_threads()) { + return false; + } + + auto routput_shape = helper->GetNodeDataShape(producer); + auto find_reducer = [&](const Node* node, const Node* reducer, const std::unordered_set& nodes_set) { + std::queue candidates; + candidates.push(node); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto producer : helper->GetProducerNode(candidate)) { + if (producer == reducer) { + return true; + } + + if (nodes_set.count(producer)) { + candidates.push(producer); + } + } + } + + return false; + }; + + for (auto node : consumer->nodes_set) { + if (helper->GetOpKind(node) != kBroadcast) { + continue; + } + + if (!find_reducer(node, producer, consumer->nodes_set)) { + continue; + } + + auto shape = absl::get>(node->attrs.attr_store.at("out_shape")); + auto axes = absl::get>(node->attrs.attr_store.at("broadcast_axes")); + for (auto& axis : axes) { + if (axis == -1) { + axis = shape.size() - 1; + } + } + + if (rinput_shape != shape) { + return false; + } + // if keep dim = true. + if (rinput_shape == shape && keep_dim) { + return true; + } else { + // if routput_shape = [1] + if (routput_shape.size() == 1 && routput_shape[0] == 1) { + return true; + } + // check [reduce_axes, axes] = {0, 1, 2, 3, 4, 5, 6, ...} + for (int idx = 0; idx < rinput_shape.size(); ++idx) { + if (!(std::find(axes.begin(), axes.end(), idx) == axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) { + return false; + } + } + return true; + } + return false; + } + return false; +} + #undef CONDITION_FUNC } // namespace pass diff --git a/cinn/hlir/pe/ir_schedule_pe.cc b/cinn/hlir/pe/ir_schedule_pe.cc index 1a87e67e64..55cdeabd38 100644 --- a/cinn/hlir/pe/ir_schedule_pe.cc +++ b/cinn/hlir/pe/ir_schedule_pe.cc @@ -408,10 +408,12 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, ir_sch.Bind(loops_tmp_out[0], "blockIdx.x"); ir_sch.Bind(loops_tmp_out[1], "threadIdx.x"); - ir_sch.Bind(loops_out[0], "blockIdx.x"); - if (loops_out.size() > 1) { - ir_sch.Bind(loops_out[1], "threadIdx.x"); + if (loops_out.size() == 1) { + ir_sch.Split(loops_out[0], {-1, 1}); } + loops_out = ir_sch.GetLoops(out->name); + ir_sch.Bind(loops_out[0], "blockIdx.x"); + ir_sch.Bind(loops_out[1], "threadIdx.x"); } for (auto &tensor : {tmp_out}) { @@ -495,25 +497,39 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, ir_sch.Split(loops[0], {-1, ir::GetLoopExtent(loops[0])}); } } - - for (auto &tensor : {reduce_tmp_out, tmp_out}) { - auto loops = ir_sch.GetLoops(tensor->name); - if (loops.size() == 1U) { - ir_sch.Bind(loops[0], "threadIdx.x"); - } else if (loops.size() > 1U) { - ir_sch.Bind(loops[0], "blockIdx.x"); - ir_sch.Bind(loops[1], "threadIdx.x"); + // bind block and thread for reduce. + // reduce_tmp_out + { + auto loops = ir_sch.GetLoops(reduce_tmp_out->name); + if (loops.size() <= 2U) { + if (ir_sch.GetLoops(tmp_out->name).size() == 1) { + ir_sch.Split(loops[0], {-1, 1}); + } + loops = ir_sch.GetLoops(reduce_tmp_out->name); } + ir_sch.Bind(loops[0], "blockIdx.x"); + ir_sch.Bind(loops[1], "threadIdx.x"); } + // tmp_out + { + auto loops = ir_sch.GetLoops(tmp_out->name); + if (loops.size() < 2U) { + ir_sch.Split(loops.back(), {-1, 1}); + loops = ir_sch.GetLoops(tmp_out->name); + } + ir_sch.Bind(loops[0], "blockIdx.x"); + ir_sch.Bind(loops[1], "threadIdx.x"); + } + // out { auto loops = ir_sch.GetLoops(out->name); - if (!loops.empty()) { - ir_sch.Bind(loops[0], "blockIdx.x"); - if (loops.size() > 1U) { - ir_sch.Bind(loops[1], "threadIdx.x"); - } + if (loops.size() < 2U) { + ir_sch.Split(loops.back(), {-1, 1}); + loops = ir_sch.GetLoops(out->name); } + ir_sch.Bind(loops[0], "blockIdx.x"); + ir_sch.Bind(loops[1], "threadIdx.x"); } for (auto &tensor : {reduce_tmp_out, tmp_out}) { @@ -654,10 +670,14 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, for (auto &tensor : {internal, tmp_out, out}) { auto loops = ir_sch.GetLoops(tensor->name); - if (!loops.empty()) ir_sch.Bind(loops[0], "blockIdx.x"); - if (loops.size() > 1) { - ir_sch.Bind(loops[1], "threadIdx.x"); + CHECK(!loops.empty()); + + if (loops.size() == 1) { + ir_sch.Split(loops[0], {-1, 1}); } + loops = ir_sch.GetLoops(tensor->name); + ir_sch.Bind(loops[0], "blockIdx.x"); + ir_sch.Bind(loops[1], "threadIdx.x"); } VLOG(3) << "After IRCudaTwoStepReduceSchedule : " << ir_sch.GetModule().GetExprs().at(0); // ir_sch.SimpleComputeAt(ir_sch.GetBlock(tmp_out->name), ir_sch.GetLoops(out->name)[0]); diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc index 65d4dcce60..e84f0e17f7 100644 --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -1204,7 +1204,8 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { Expr source_expr{nullptr}; Expr target_expr{nullptr}; - LeafBlockRemovalPlan remove_plan(result.As() ? result : this_block, &source_expr, &target_expr); + LeafBlockRemovalPlan remove_plan( + result.As() ? block_loops[loops.size()] : this_block, &source_expr, &target_expr); remove_plan(&root); this->Replace(source_expr, target_expr);