From cedf56224f6c0372ee90184e754ed7d2ba00436e Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Thu, 9 Nov 2023 11:28:27 +0000 Subject: [PATCH 1/9] Adapt to pir --- paddle/cinn/adt/CMakeLists.txt | 84 +++++----- paddle/cinn/adt/adapter_dynamic_tensor.h | 23 +-- paddle/cinn/adt/adapter_tensor.cc | 44 +++++ paddle/cinn/adt/adapter_tensor.h | 45 +----- paddle/cinn/adt/equation_value.h | 5 - paddle/cinn/adt/generate_map_expr.cc | 133 +++++++-------- paddle/cinn/adt/generate_map_expr.h | 22 ++- .../cinn/adt/graph_symbolic_dim_infer_ctx.cc | 151 ++++++++---------- .../cinn/adt/graph_symbolic_dim_infer_ctx.h | 45 +++--- paddle/cinn/adt/kgroup.h | 12 +- paddle/cinn/adt/map_expr.h | 21 ++- paddle/cinn/adt/map_expr_ctx.h | 6 +- paddle/cinn/adt/naive_op_equation_context.cc | 62 +++---- paddle/cinn/adt/naive_op_equation_context.h | 18 ++- paddle/cinn/adt/print_equations.cc | 20 ++- paddle/cinn/adt/print_map_expr.cc | 19 ++- paddle/cinn/adt/symbolic_dim_infer_ctx.h | 10 +- paddle/cinn/adt/symbolic_dim_infer_util.cc | 33 ++-- paddle/cinn/adt/symbolic_dim_infer_util.h | 6 +- .../transforms/cinn_group_lowering_pass.cc | 8 + paddle/cinn/hlir/framework/graph.h | 37 ----- .../cinn/hlir/framework/op_lowering_impl.cc | 59 ------- paddle/cinn/hlir/framework/op_lowering_impl.h | 18 --- paddle/cinn/hlir/framework/pir/group.h | 52 +++++- .../hlir/framework/pir/op_lowering_impl.cc | 77 ++++++++- .../hlir/framework/pir/op_lowering_impl.h | 22 +++ paddle/cinn/hlir/pe/map_expr_to_ir.cc | 108 +++++++++---- paddle/cinn/hlir/pe/map_expr_to_ir.h | 1 + .../st_shape_group_scheduler.cc | 7 + .../group_schedule/st_shape_group_scheduler.h | 2 + paddle/cinn/pybind/frontend.cc | 2 - paddle/cinn/runtime/flags.cc | 10 +- .../framework/paddle2cinn/cinn_compiler.cc | 2 - test/cinn/CMakeLists.txt | 23 --- test/cinn/adt/test_add_inline.py | 60 ------- test/cinn/adt/test_broadcast_expr.py | 59 ------- test/cinn/adt/test_fusion_ability.py | 65 -------- test/cinn/adt/test_naive_add.py | 57 ------- test/cinn/adt/test_naive_reduce.py | 50 ------ test/cinn/adt/test_reduce_fusion.py | 62 ------- test/cinn/adt/test_reduce_schedule_mesh.py | 59 ------- 41 files changed, 623 insertions(+), 976 deletions(-) create mode 100644 paddle/cinn/adt/adapter_tensor.cc delete mode 100755 test/cinn/adt/test_add_inline.py delete mode 100755 test/cinn/adt/test_broadcast_expr.py delete mode 100644 test/cinn/adt/test_fusion_ability.py delete mode 100755 test/cinn/adt/test_naive_add.py delete mode 100644 test/cinn/adt/test_naive_reduce.py delete mode 100644 test/cinn/adt/test_reduce_fusion.py delete mode 100644 test/cinn/adt/test_reduce_schedule_mesh.py diff --git a/paddle/cinn/adt/CMakeLists.txt b/paddle/cinn/adt/CMakeLists.txt index 0997562ca548b1..b88284de4b772b 100644 --- a/paddle/cinn/adt/CMakeLists.txt +++ b/paddle/cinn/adt/CMakeLists.txt @@ -1,45 +1,49 @@ -core_gather_headers() +if(NOT CINN_ONLY) + core_gather_headers() -gather_srcs( - cinnapi_src - SRCS - anchor_sd_equation_context.cc - equation_function.cc - equation_solver.cc - equation_value.cc - generate_map_expr.cc - get_sub_reshape_dim_ranges.cc - graph_symbolic_dim_infer_ctx.cc - igroup.cc - index_expr_infer_context.cc - kgroup.cc - m_ir.cc - naive_bidirection_equation_generator.cc - naive_op_equation_context.cc - partition_op_stmts.cc - print_equations.cc - print_map_expr.cc - print_schedule_descriptor.cc - print_schedule_dim.cc - print_schedule_mesh.cc - print_value.cc - schedule_descriptor.cc - schedule_dim.cc - schedule_mesh.cc - dim_expr.cc - dim_expr_test.cc - dim_expr_simplifier.cc - symbolic_dim_infer_util.cc - simplify_value.cc - write_broadcast_disabled_bidirection_equation_generator.cc - print_dim_expr.cc) + gather_srcs( + cinnapi_src + SRCS + adapter_tensor.cc + anchor_sd_equation_context.cc + equation_function.cc + equation_solver.cc + equation_value.cc + generate_map_expr.cc + get_sub_reshape_dim_ranges.cc + graph_symbolic_dim_infer_ctx.cc + igroup.cc + index_expr_infer_context.cc + kgroup.cc + m_ir.cc + naive_bidirection_equation_generator.cc + naive_op_equation_context.cc + partition_op_stmts.cc + print_equations.cc + print_map_expr.cc + print_schedule_descriptor.cc + print_schedule_dim.cc + print_schedule_mesh.cc + print_value.cc + schedule_descriptor.cc + schedule_dim.cc + schedule_mesh.cc + dim_expr.cc + dim_expr_test.cc + dim_expr_simplifier.cc + symbolic_dim_infer_util.cc + simplify_value.cc + write_broadcast_disabled_bidirection_equation_generator.cc + print_dim_expr.cc) -cinn_cc_test(equation_value_match_trait_test SRCS - equation_value_match_trait_test.cc DEPS gtest glog) + cinn_cc_test(equation_value_match_trait_test SRCS + equation_value_match_trait_test.cc DEPS gtest glog) -cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog) + cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog) -cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS gtest - glog) + cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS gtest + glog) -message(STATUS "ADT srcs: ${cinnapi_src}") + message(STATUS "ADT srcs: ${cinnapi_src}") + +endif() diff --git a/paddle/cinn/adt/adapter_dynamic_tensor.h b/paddle/cinn/adt/adapter_dynamic_tensor.h index b5408a95401e50..a36c64b19fc449 100644 --- a/paddle/cinn/adt/adapter_dynamic_tensor.h +++ b/paddle/cinn/adt/adapter_dynamic_tensor.h @@ -17,37 +17,30 @@ #include "paddle/cinn/adt/adt.h" #include "paddle/cinn/adt/symbolic_dim.h" -#include "paddle/cinn/hlir/framework/graph.h" -#include "paddle/cinn/hlir/framework/node.h" +#include "paddle/cinn/hlir/framework/pir/group.h" namespace cinn::adt::adapter { struct DynamicTensor final { - const hlir::framework::NodeData* node_data; - const hlir::framework::Graph* graph; + ::pir::Value node_data; + const hlir::framework::pir::Group* group; bool operator==(const DynamicTensor& other) const { - return this->node_data == other.node_data && this->graph == other.graph; + return this->node_data == other.node_data; } std::size_t GetRank() const { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(node_data->id())) - << "Can't find " << node_data->id() << " 's shape!"; - return shape_dict.at(node_data->id()).size(); + return cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data) + .size(); } const std::vector>& GetShape() const { - return graph->graph_ctx()->GetTensorDimExprs(node_data); + return group->graph_symbolic_dim_infer_ctx()->GetTensorDimExprs(node_data); } }; inline std::size_t GetHashValueImpl(const DynamicTensor& tensor) { - return hash_combine( - std::hash()(tensor.node_data), - std::hash()(tensor.graph)); + return std::hash<::pir::Value>()(tensor.node_data); } } // namespace cinn::adt::adapter diff --git a/paddle/cinn/adt/adapter_tensor.cc b/paddle/cinn/adt/adapter_tensor.cc new file mode 100644 index 00000000000000..464c45780dbecd --- /dev/null +++ b/paddle/cinn/adt/adapter_tensor.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/adt/adapter_tensor.h" +#include "glog/logging.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" + +namespace cinn::adt::adapter { + +std::size_t Tensor::GetRank() const { + return cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data) + .size(); +} + +std::vector Tensor::GetShape() const { + std::vector ret{}; + for (int dim_size : + cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) { + ret.emplace_back(dim_size); + } + return ret; +} + +std::size_t Tensor::GetNumel() const { + std::size_t ret = 1; + for (int dim_size : + cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) { + ret = ret * dim_size; + } + return ret; +} + +} // namespace cinn::adt::adapter diff --git a/paddle/cinn/adt/adapter_tensor.h b/paddle/cinn/adt/adapter_tensor.h index 2a6cc941afb89e..dbd2c2dcecfdbb 100644 --- a/paddle/cinn/adt/adapter_tensor.h +++ b/paddle/cinn/adt/adapter_tensor.h @@ -13,59 +13,28 @@ // limitations under the License. #pragma once -#include "glog/logging.h" #include "paddle/cinn/adt/adt.h" -#include "paddle/cinn/hlir/framework/graph.h" -#include "paddle/cinn/hlir/framework/node.h" +#include "paddle/pir/core/value.h" namespace cinn::adt::adapter { struct Tensor final { - const hlir::framework::NodeData* node_data; - const hlir::framework::Graph* graph; + ::pir::Value node_data; bool operator==(const Tensor& other) const { - return this->node_data == other.node_data && this->graph == other.graph; + return this->node_data == other.node_data; } - std::size_t GetRank() const { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(node_data->id())) - << "Can't find " << node_data->id() << " 's shape!"; - return shape_dict.at(node_data->id()).size(); - } + std::size_t GetRank() const; - const std::vector& GetShape() const { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(node_data->id())) - << "Can't find " << node_data->id() << " 's shape!"; - return shape_dict.at(node_data->id()); - } + std::vector GetShape() const; - std::size_t GetNumel() const { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(node_data->id())) - << "Can't find " << node_data->id() << " 's shape!"; - std::vector shape = shape_dict.at(node_data->id()); - std::size_t ret = 1; - for (int32_t dim_size : shape) { - ret = ret * dim_size; - } - return ret; - } + std::size_t GetNumel() const; }; inline std::size_t GetHashValueImpl(const Tensor& tensor) { - return hash_combine( - std::hash()(tensor.node_data), - std::hash()(tensor.graph)); + return std::hash<::pir::Value>()(tensor.node_data); } } // namespace cinn::adt::adapter diff --git a/paddle/cinn/adt/equation_value.h b/paddle/cinn/adt/equation_value.h index a876ffef1bf6d4..1d0c2a134423a5 100644 --- a/paddle/cinn/adt/equation_value.h +++ b/paddle/cinn/adt/equation_value.h @@ -20,11 +20,6 @@ #include "paddle/cinn/adt/equation.h" #include "paddle/cinn/adt/match.h" -namespace cinn::hlir::framework { -class Node; -class NodeData; -} // namespace cinn::hlir::framework - namespace cinn::adt { DEFINE_ADT_TAG(tPointer); diff --git a/paddle/cinn/adt/generate_map_expr.cc b/paddle/cinn/adt/generate_map_expr.cc index 5782902c775f35..c5d235682ab497 100644 --- a/paddle/cinn/adt/generate_map_expr.cc +++ b/paddle/cinn/adt/generate_map_expr.cc @@ -28,7 +28,11 @@ #include "paddle/cinn/adt/schedule_descriptor.h" #include "paddle/cinn/adt/symbolic_dim_infer_util.h" #include "paddle/cinn/adt/tree.h" +#include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/runtime/flags.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" #include "glog/logging.h" @@ -87,25 +91,24 @@ using LoopDescriptor4IterVarT = std::function; using AnchorTensor = Variable; using FakeOpPlaceHolders = List; -Op MakeOp(const hlir::framework::Node* op) { return {op}; } +Op MakeOp(const ::pir::Operation* op) { return {op}; } template -void VisitEachInputTensor(const hlir::framework::Node* op, - const DoEachT& DoEach) { - for (const auto& graph_edge : op->inlinks_in_order()) { - DoEach(graph_edge->source()->safe_as()); +void VisitEachInputTensor(const ::pir::Operation* op, const DoEachT& DoEach) { + for (std::size_t i = 0; i < op->num_operands(); ++i) { + DoEach(op->operand_source(i)); } } -List MakeOpStmtInputList(const hlir::framework::Node* op, - const hlir::framework::Graph* graph) { +List MakeOpStmtInputList(const ::pir::Operation* op, + const hlir::framework::pir::Group* group) { List ret{}; - VisitEachInputTensor(op, [&](const auto* tensor) { + VisitEachInputTensor(op, [&](const ::pir::Value& tensor) { if (FLAGS_cinn_map_expr_enable_dynamic_shape) { - ret->emplace_back(adapter::DynamicTensor{tensor, graph}); + ret->emplace_back(adapter::DynamicTensor{tensor, group}); } else { - ret->emplace_back(adapter::Tensor{tensor, graph}); + ret->emplace_back(adapter::Tensor{tensor}); } }); @@ -113,22 +116,21 @@ List MakeOpStmtInputList(const hlir::framework::Node* op, } template -void VisitEachOutputTensor(const hlir::framework::Node* op, - const DoEachT& DoEach) { - for (const auto& graph_edge : op->outlinks_in_order()) { - DoEach(graph_edge->sink()->safe_as()); +void VisitEachOutputTensor(const ::pir::Operation* op, const DoEachT& DoEach) { + for (std::size_t i = 0; i < op->num_results(); ++i) { + DoEach(const_cast<::pir::Operation*>(op)->result(i)); } } -List MakeOpStmtOutputList(const hlir::framework::Node* op, - const hlir::framework::Graph* graph) { +List MakeOpStmtOutputList(const ::pir::Operation* op, + const hlir::framework::pir::Group* group) { List ret{}; - VisitEachOutputTensor(op, [&](const auto* tensor) { + VisitEachOutputTensor(op, [&](const ::pir::Value& tensor) { if (FLAGS_cinn_map_expr_enable_dynamic_shape) { - ret->emplace_back(adapter::DynamicTensor{tensor, graph}); + ret->emplace_back(adapter::DynamicTensor{tensor, group}); } else { - ret->emplace_back(adapter::Tensor{tensor, graph}); + ret->emplace_back(adapter::Tensor{tensor}); } }); @@ -136,38 +138,30 @@ List MakeOpStmtOutputList(const hlir::framework::Node* op, } template -void VisitEachOpStmt( - const std::shared_ptr& group, - const DoEachT& DoEach) { - // Note - for (const auto* op : group->CollectNodes()) { +void VisitEachOpStmt(const std::shared_ptr& group, + const DoEachT& DoEach) { + for (const auto* op : group->CollectOps()) { DoEach(OpStmt{MakeOp(op), - MakeOpStmtInputList(op, group->graph_), - MakeOpStmtOutputList(op, group->graph_)}); + MakeOpStmtInputList(op, group.get()), + MakeOpStmtOutputList(op, group.get())}); } } -hlir::framework::OpPatternKind GetOpPatternKind( - const hlir::framework::Node* node) { - static const hlir::framework::OpValueType& - op_pattern_dict = - hlir::framework::Operator::GetAttrs( - "OpPattern"); - auto kind = op_pattern_dict[node->op()]; - return kind; +hlir::framework::OpPatternKind GetOpPatternKind(const ::pir::Operation* node) { + return hlir::framework::pir::CompatibleInfo::OpKind(*node); } bool CollectRewritedReductionOpStmts(const OpStmt& op_stmt, List* ret) { const auto& [op, inputs, outputs] = op_stmt.tuple(); - CHECK(op.Has()); - if (GetOpPatternKind(op.Get()) == + CHECK(op.Has()); + if (GetOpPatternKind(op.Get()) == hlir::framework::OpPatternKind::kReduction) { - tReduceInit init_op{ - op.Get()}; + tReduceInit init_op{ + op.Get()}; (*ret)->emplace_back(OpStmt{init_op, List{}, outputs}); - tReduceAcc acc_op{ - op.Get()}; + tReduceAcc acc_op{ + op.Get()}; (*ret)->emplace_back(OpStmt{acc_op, inputs, outputs}); return true; } else { @@ -183,7 +177,7 @@ void CollectRewritedOpStmts(const OpStmt& op_stmt, List* ret) { } List MakeOpStmts( - const std::shared_ptr& group) { + const std::shared_ptr& group) { List ret{}; VisitEachOpStmt(group, [&](const auto& op_stmt) { @@ -212,15 +206,14 @@ std::shared_ptr MakeIGroup(const AnchorGroup& igroup_spec) { std::shared_ptr direction_equation_generator{ new NaiveBidirectionEquationGenerator{igroup_spec.op_stmts, igroup_spec.EquationCtx4OpStmt}}; - CheckEquationSolvable( - igroup_spec, direction_equation_generator); + CheckEquationSolvable(igroup_spec, direction_equation_generator); return std::make_shared(igroup_spec.op_stmts, igroup_spec.anchor_index, igroup_spec.EquationCtx4OpStmt); } std::vector> GenerateIGroups( - const std::shared_ptr& group) { + const std::shared_ptr& group) { std::vector> ret{}; List op_stmts = MakeOpStmts(group); @@ -234,7 +227,7 @@ std::vector> GenerateIGroups( } std::shared_ptr GenerateKGroups( - const std::shared_ptr& group, + const std::shared_ptr& group, const std::vector>& igroups) { CHECK_EQ(igroups.size(), 1); return std::make_shared(group, igroups); @@ -274,8 +267,7 @@ std::shared_ptr SolveEquationsThenReturnCtx( GraphView merged_view = igroup_view.Merge(sd_equation_graph_view); const auto& init_var2value = MakeSdIterator2Iterator(*igroup); - auto ctx = std::make_shared( - init_var2value); + auto ctx = std::make_shared(init_var2value); std::vector starts{}; for (const auto& loop_iterator : *igroup->loop_iterators()) { @@ -350,36 +342,34 @@ Tensor GetAnchorTensor(const std::shared_ptr& igroup) { } template -void VisitInputTensor(const hlir::framework::Graph::Group& group, +void VisitInputTensor(const hlir::framework::pir::Group& group, const DoEachT& DoEach) { - for (const auto* node_data : group.GetInputNodeDatas()) { - DoEach(node_data, group.graph_); + for (const ::pir::Value& node_data : group.GetInputOpValues()) { + DoEach(node_data); } } template -void VisitOutputTensor(const hlir::framework::Graph::Group& group, +void VisitOutputTensor(const hlir::framework::pir::Group& group, const DoEachT& DoEach) { - for (const auto& node_data : group.GetOutputNodeDatas()) { - DoEach(node_data, group.graph_); + for (const ::pir::Value& node_data : group.GetOutputOpValues()) { + DoEach(node_data); } } List MakeInputTensors(const std::shared_ptr& kgroup) { List ret{}; - VisitInputTensor(*kgroup->cinn_group(), - [&](const auto* node_data, const auto* graph) { - ret->emplace_back(adapter::Tensor{node_data, graph}); - }); + VisitInputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) { + ret->emplace_back(adapter::Tensor{node_data}); + }); return ret; } List MakeOutputTensors(const std::shared_ptr& kgroup) { List ret{}; - VisitOutputTensor(*kgroup->cinn_group(), - [&](const auto* node_data, const auto* graph) { - ret->emplace_back(adapter::Tensor{node_data, graph}); - }); + VisitOutputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) { + ret->emplace_back(adapter::Tensor{node_data}); + }); return ret; } @@ -444,7 +434,7 @@ MapExpr GenerateMapExpr(const std::shared_ptr& kgroup) { } // namespace MapExpr GenerateMapExpr( - const std::shared_ptr& group) { + const std::shared_ptr& group) { const auto& igroups = GenerateIGroups(group); const auto& kgroup = GenerateKGroups(group, igroups); @@ -453,16 +443,29 @@ MapExpr GenerateMapExpr( } void TryGenerateMapExprFromGraph( - const std::shared_ptr& graph) { + const hlir::framework::pir::GroupList& groups) { if (!FLAGS_cinn_enable_map_expr) { return; } - graph->set_graph_ctx(adt::InferSymbolicDim(graph.get())); - for (const auto& fusion_group : graph->fusion_groups) { + for (const auto& fusion_group : groups) { + fusion_group->set_graph_symbolic_dim_infer_ctx( + adt::InferSymbolicDim(fusion_group.get())); const auto& map_expr = GenerateMapExpr(fusion_group); VLOG(1) << ToTxtString(map_expr, fusion_group->group_id); fusion_group->set_map_expr_ctx(std::make_shared(map_expr)); } } +void TryGenerateMapExprFromGroup( + const std::shared_ptr& fusion_group) { + if (!FLAGS_cinn_enable_map_expr) { + return; + } + fusion_group->set_graph_symbolic_dim_infer_ctx( + adt::InferSymbolicDim(fusion_group.get())); + const auto& map_expr = GenerateMapExpr(fusion_group); + VLOG(1) << ToTxtString(map_expr, fusion_group->group_id); + fusion_group->set_map_expr_ctx(std::make_shared(map_expr)); +} + } // namespace cinn::adt diff --git a/paddle/cinn/adt/generate_map_expr.h b/paddle/cinn/adt/generate_map_expr.h index cd2970559cb6b0..e9235dd625270b 100644 --- a/paddle/cinn/adt/generate_map_expr.h +++ b/paddle/cinn/adt/generate_map_expr.h @@ -14,19 +14,25 @@ #pragma once -#include "paddle/cinn/adt/m_ir.h" +#include + #include "paddle/cinn/adt/map_expr.h" -#include "paddle/cinn/hlir/framework/graph.h" -namespace cinn::adt { +namespace cinn::hlir::framework::pir { + +struct Group; +using GroupList = std::vector>; -class IGroup; -class KGroup; +} // namespace cinn::hlir::framework::pir + +namespace cinn::adt { MapExpr GenerateMapExpr( - const std::shared_ptr& group); + const std::shared_ptr& group); + +void TryGenerateMapExprFromGraph(const hlir::framework::pir::GroupList& groups); -void TryGenerateMapExprFromGraph( - const std::shared_ptr& graph); +void TryGenerateMapExprFromGroup( + const std::shared_ptr& fusion_group); } // namespace cinn::adt diff --git a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc index 83fc92a203c0f6..0a4df9532fe2c7 100644 --- a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc +++ b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc @@ -16,57 +16,41 @@ #include "paddle/cinn/adt/dim_expr_simplifier.h" #include "paddle/cinn/adt/unique_id.h" -#include "paddle/cinn/common/graph_utils.h" -#include "paddle/cinn/hlir/framework/graph.h" -#include "paddle/cinn/hlir/framework/node.h" +#include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" namespace cinn::adt::config { namespace { -const std::vector& GetShape(const hlir::framework::Graph* graph, - const hlir::framework::NodeData* tensor) { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(tensor->id())) - << "Can't find " << tensor->id() << " 's shape!"; - return shape_dict.at(tensor->id()); +std::vector GetShape(const ::pir::Value& tensor) { + std::vector tensor_shape = + hlir::framework::pir::CompatibleInfo::ValueShape(tensor); + std::vector ret{}; + for (int32_t dim : tensor_shape) { + ret.push_back(dim); + } + return ret; } -std::size_t GetTensorRank(const hlir::framework::Graph* graph, - const hlir::framework::NodeData* tensor) { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(tensor->id())) - << "Can't find " << tensor->id() << " 's shape!"; - return shape_dict.at(tensor->id()).size(); +std::size_t GetTensorRank(const ::pir::Value& tensor) { + return hlir::framework::pir::CompatibleInfo::ValueShape(tensor).size(); } -std::vector GetOpInputRanks(const hlir::framework::Graph* graph, - const hlir::framework::Node* node) { +std::vector GetOpInputRanks(const ::pir::Operation* node) { std::vector ret{}; - for (const auto& graph_edge : node->inlinks_in_order()) { - const hlir::framework::NodeData* tensor = - graph_edge->source()->safe_as(); - ret.emplace_back(GetTensorRank(graph, tensor)); + for (const ::pir::Value& tensor : node->operands_source()) { + ret.emplace_back(GetTensorRank(tensor)); } return ret; } -std::vector GetTopoOrderOpNodes( - const hlir::framework::Graph* graph) { - std::vector ret{}; - std::vector topo_nodes = - std::get<0>(graph->topological_order()); - for (const common::GraphNode* graph_node : topo_nodes) { - const hlir::framework::Node* op_node = - graph_node->safe_as(); - // if node is NodeData or not op, continue. - if (!op_node || op_node->op() == nullptr) { - continue; - } +std::vector GetTopoOrderOpNodes( + const hlir::framework::pir::Group* group) { + std::vector ret{}; + for (const ::pir::Operation* op_node : group->ops) { ret.emplace_back(op_node); } return ret; @@ -75,8 +59,8 @@ std::vector GetTopoOrderOpNodes( } // namespace void GraphSymbolicDimInferCtx::InitOp2TensorRanks() { - for (const hlir::framework::Node* op_node : GetTopoOrderOpNodes(graph_)) { - const auto& input_ranks = GetOpInputRanks(graph_, op_node); + for (const ::pir::Operation* op_node : GetTopoOrderOpNodes(group_)) { + std::vector input_ranks = GetOpInputRanks(op_node); if (op2input_ranks_.find(op_node) == op2input_ranks_.end()) { op2input_ranks_.emplace(op_node, input_ranks); } else { @@ -88,30 +72,31 @@ void GraphSymbolicDimInferCtx::InitOp2TensorRanks() { namespace { std::unordered_set GetAllOutputNames( - const std::vector& nodes) { + const std::vector& nodes) { std::unordered_set output_names; - for (const auto* node : nodes) { - for (const auto& link : node->outlinks()) { - const auto* out_node = link->sink()->safe_as(); - output_names.emplace(out_node->id()); + for (const auto* op_node : nodes) { + for (const ::pir::Value& out_node : + const_cast<::pir::Operation*>(op_node)->results()) { + output_names.emplace( + hlir::framework::pir::CompatibleInfo::ValueName(out_node)); } } return output_names; } -std::vector GetFeedList( - const std::vector& nodes, +std::vector<::pir::Value> GetFeedList( + const std::vector& op_nodes, const std::unordered_set& out_names) { - std::vector ret{}; + std::vector<::pir::Value> ret{}; // if the op's input var name cannot found in out_names, it is the group's // feed var std::unordered_set feed_names; - for (const auto* node : nodes) { - for (const auto& link : node->inlinks()) { - const auto* in_node = - link->source()->safe_as(); - if (!out_names.count(in_node->id()) && !feed_names.count(in_node->id())) { - feed_names.emplace(in_node->id()); + for (const auto* op_node : op_nodes) { + for (const ::pir::Value in_node : op_node->operands_source()) { + const auto& node_id = + hlir::framework::pir::CompatibleInfo::ValueName(in_node); + if (!out_names.count(node_id) && !feed_names.count(node_id)) { + feed_names.emplace(node_id); ret.emplace_back(in_node); } } @@ -120,15 +105,13 @@ std::vector GetFeedList( } std::vector> MakeDimExprForTensor( - const hlir::framework::Graph* graph, - const hlir::framework::NodeData* node_data) { + const ::pir::Value& node_data) { std::vector> ret{}; - const std::vector& shape = GetShape(graph, node_data); + std::vector shape = GetShape(node_data); for (std::size_t i = 0; i < shape.size(); ++i) { if (i == 0) { - static DimExpr temp_elementwise_dim_expr{ - SymbolicDim{UniqueId::New()}}; + static DimExpr temp_elementwise_dim_expr{SymbolicDim{UniqueId::New()}}; ret.emplace_back(temp_elementwise_dim_expr); } else { ret.emplace_back(DimExpr{shape.at(i)}); @@ -140,38 +123,34 @@ std::vector> MakeDimExprForTensor( } // namespace void GraphSymbolicDimInferCtx::InitGraphInputDimExpr() { - std::vector topo_op_nodes = - GetTopoOrderOpNodes(graph_); - std::vector feed_list = + std::vector topo_op_nodes = + GetTopoOrderOpNodes(group_); + std::vector<::pir::Value> feed_list = GetFeedList(topo_op_nodes, GetAllOutputNames(topo_op_nodes)); - for (const hlir::framework::NodeData* node_data : feed_list) { - CHECK( - tensor2dim_exprs_ - .emplace(node_data, MakeDimExprForTensor(graph_, node_data)) - .second); + for (const ::pir::Value node_data : feed_list) { + CHECK(tensor2dim_exprs_.emplace(node_data, MakeDimExprForTensor(node_data)) + .second); } } const std::vector& GraphSymbolicDimInferCtx::GetInTensorsRanks( - const hlir::framework::Node* node) const { + const ::pir::Operation* node) const { const auto& iter = op2input_ranks_.find(node); CHECK(iter != op2input_ranks_.end()); return iter->second; } std::uint64_t GraphSymbolicDimInferCtx::GetNumOutTensors( - const hlir::framework::Node* node) const { - return node->outlinks_in_order().size(); + const ::pir::Operation* node) const { + return node->num_results(); } const DimExpr& GraphSymbolicDimInferCtx::GetInputDimExpr( - const hlir::framework::Node* node, + const ::pir::Operation* node, std::size_t arg_idx, std::size_t dim_idx) const { - const auto& edges = node->inlinks_in_order(); - CHECK_LT(arg_idx, edges.size()); - const hlir::framework::NodeData* tensor = - edges.at(arg_idx)->source()->safe_as(); + CHECK_LT(arg_idx, node->num_operands()); + const ::pir::Value tensor = node->operand_source(arg_idx); const auto& iter = tensor2dim_exprs_.find(tensor); CHECK(iter != tensor2dim_exprs_.end()); CHECK_LT(dim_idx, iter->second.size()); @@ -180,16 +159,14 @@ const DimExpr& GraphSymbolicDimInferCtx::GetInputDimExpr( return opt_dim_expr.value(); } -void GraphSymbolicDimInferCtx::SetOutputDimExpr( - const hlir::framework::Node* node, - std::size_t arg_idx, - std::size_t dim_idx, - const DimExpr& value) { - const auto& edges = node->outlinks_in_order(); - CHECK_LT(arg_idx, edges.size()); - const hlir::framework::NodeData* tensor = - edges.at(arg_idx)->sink()->safe_as(); - std::size_t rank = GetTensorRank(graph_, tensor); +void GraphSymbolicDimInferCtx::SetOutputDimExpr(const ::pir::Operation* node, + std::size_t arg_idx, + std::size_t dim_idx, + const DimExpr& value) { + CHECK_LT(arg_idx, node->num_results()); + const ::pir::Value tensor = + const_cast<::pir::Operation*>(node)->result(arg_idx); + std::size_t rank = GetTensorRank(tensor); CHECK_LT(dim_idx, rank); auto* opt_symbolic_dims = &tensor2dim_exprs_[tensor]; if (dim_idx >= opt_symbolic_dims->size()) { @@ -198,9 +175,9 @@ void GraphSymbolicDimInferCtx::SetOutputDimExpr( opt_symbolic_dims->at(dim_idx) = SimplifyDimExpr(value); } -const hlir::framework::AttrMapType& GraphSymbolicDimInferCtx::GetAttributeMap( - const hlir::framework::Node* op_node) const { - return op_node->attrs.attr_store; +cinn::utils::AttributeMap GraphSymbolicDimInferCtx::GetAttributeMap( + const ::pir::Operation* op_node) const { + return hlir::framework::pir::CompatibleInfo::ConvertAttributes(*op_node); } } // namespace cinn::adt::config diff --git a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h index 58bda1f4e9b2eb..37e040508b69d1 100644 --- a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h +++ b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h @@ -19,13 +19,15 @@ #include #include "paddle/cinn/adt/dim_expr.h" -#include "paddle/cinn/hlir/framework/node.h" +#include "paddle/cinn/utils/type_defs.h" +#include "paddle/pir/core/value.h" +namespace pir { +class Operation; +} -namespace cinn::hlir::framework { -class Graph; -class NodeData; -class Node; -} // namespace cinn::hlir::framework +namespace cinn::hlir::framework::pir { +struct Group; +} // namespace cinn::hlir::framework::pir namespace cinn::adt::config { @@ -34,47 +36,46 @@ class GraphSymbolicDimInferCtx { GraphSymbolicDimInferCtx(const GraphSymbolicDimInferCtx&) = delete; GraphSymbolicDimInferCtx(GraphSymbolicDimInferCtx&&) = delete; - explicit GraphSymbolicDimInferCtx(const hlir::framework::Graph* graph) - : graph_(graph) { + explicit GraphSymbolicDimInferCtx( + const cinn::hlir::framework::pir::Group* group) + : group_(group) { InitOp2TensorRanks(); InitGraphInputDimExpr(); } - const hlir::framework::Graph* graph() const { return graph_; } + const cinn::hlir::framework::pir::Group* group() const { return group_; } const std::vector& GetInTensorsRanks( - const hlir::framework::Node* node) const; + const ::pir::Operation* node) const; - std::uint64_t GetNumOutTensors(const hlir::framework::Node* node) const; + std::uint64_t GetNumOutTensors(const ::pir::Operation* node) const; - const DimExpr& GetInputDimExpr(const hlir::framework::Node* node, - std::size_t arg_idx, - std::size_t dim_idx) const; + const DimExpr& GetInputDimExpr(const ::pir::Operation* node, + std::size_t arg_idx, + std::size_t dim_idx) const; const std::vector>& GetTensorDimExprs( - const hlir::framework::NodeData* tensor) const { + const ::pir::Value tensor) const { const auto& iter = tensor2dim_exprs_.find(tensor); CHECK(iter != tensor2dim_exprs_.end()); return iter->second; } - void SetOutputDimExpr(const hlir::framework::Node* node, + void SetOutputDimExpr(const ::pir::Operation* node, std::size_t arg_idx, std::size_t dim_idx, const DimExpr& value); - const hlir::framework::AttrMapType& GetAttributeMap( - const hlir::framework::Node* node) const; + cinn::utils::AttributeMap GetAttributeMap(const ::pir::Operation* node) const; private: void InitOp2TensorRanks(); void InitGraphInputDimExpr(); - const hlir::framework::Graph* graph_; - std::unordered_map>> + const cinn::hlir::framework::pir::Group* group_; + std::unordered_map<::pir::Value, std::vector>> tensor2dim_exprs_; - std::unordered_map> + std::unordered_map> op2input_ranks_; }; diff --git a/paddle/cinn/adt/kgroup.h b/paddle/cinn/adt/kgroup.h index 9227e5500cb9ae..ba071ed86f242f 100644 --- a/paddle/cinn/adt/kgroup.h +++ b/paddle/cinn/adt/kgroup.h @@ -19,6 +19,12 @@ #include "paddle/cinn/adt/map_expr.h" +namespace cinn::hlir::framework::pir { + +struct Group; + +} // namespace cinn::hlir::framework::pir + namespace cinn::adt { class IGroup; @@ -27,11 +33,11 @@ using cinn::adt::LoopDescriptors; class KGroup final { public: explicit KGroup( - const std::shared_ptr& cinn_group, + const std::shared_ptr& cinn_group, const std::vector>& igroups) : cinn_group_(cinn_group), igroups_(igroups) {} - std::shared_ptr cinn_group() const { + std::shared_ptr cinn_group() const { return CHECK_NOTNULL(cinn_group_.lock()); } @@ -46,7 +52,7 @@ class KGroup final { const std::shared_ptr& igroup) const; private: - std::weak_ptr cinn_group_; + std::weak_ptr cinn_group_; // NOTE: Use single igroup temporarily. Actually KGroup contains // multiple IGroups std::vector> igroups_; diff --git a/paddle/cinn/adt/map_expr.h b/paddle/cinn/adt/map_expr.h index de7ebe39acbdb6..05cfd7ef277e8f 100644 --- a/paddle/cinn/adt/map_expr.h +++ b/paddle/cinn/adt/map_expr.h @@ -27,14 +27,9 @@ #include "paddle/cinn/adt/tags.h" #include "paddle/cinn/adt/tree.h" -namespace cinn { -namespace hlir { -namespace framework { -class Node; -class NodeData; -} // namespace framework -} // namespace hlir -} // namespace cinn +namespace pir { +class Operation; +} namespace cinn { namespace adt { @@ -90,11 +85,13 @@ DEFINE_ADT_UNION(Tensor, adapter::Tensor, adapter::DynamicTensor, TempStorage); OVERRIDE_UNION_GET_HASH_VALUE(Tensor); OVERLOAD_OPERATOR_EQ_NE(Tensor, UnionEqual); -// Op = const Node* +// Op = const pir::Operation* +// | tReduceInit +// | tReduceAcc DEFINE_ADT_UNION(Op, - const hlir::framework::Node*, - tReduceInit, - tReduceAcc); + const ::pir::Operation*, + tReduceInit, + tReduceAcc); using Arg = Tensor; diff --git a/paddle/cinn/adt/map_expr_ctx.h b/paddle/cinn/adt/map_expr_ctx.h index 53978129228fa0..75f56155b72a4e 100644 --- a/paddle/cinn/adt/map_expr_ctx.h +++ b/paddle/cinn/adt/map_expr_ctx.h @@ -18,16 +18,16 @@ #include #include "paddle/cinn/adt/map_expr.h" -#include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/pir/core/operation.h" namespace cinn::adt { class MapExprCtx final { public: using Node2LoweredFuncs = - std::unordered_map>; + std::unordered_map<::pir::Operation*, std::vector>; MapExprCtx(const MapExprCtx&) = delete; MapExprCtx(MapExprCtx&&) = delete; @@ -37,7 +37,7 @@ class MapExprCtx final { const MapExpr& map_expr() const { return map_expr_; } void UpdateOpLoweredFuncKey( - hlir::framework::Node* node, + ::pir::Operation* node, const std::vector& lowered_funcs) { Node2LoweredFuncs* map = &node2lowered_funcs_; CHECK(map->emplace(node, ir::ir_utils::IRCopy(lowered_funcs)).second); diff --git a/paddle/cinn/adt/naive_op_equation_context.cc b/paddle/cinn/adt/naive_op_equation_context.cc index 21524a20c54a7c..880bda82857967 100644 --- a/paddle/cinn/adt/naive_op_equation_context.cc +++ b/paddle/cinn/adt/naive_op_equation_context.cc @@ -19,6 +19,8 @@ #include "paddle/cinn/adt/naive_op_equation_context.h" #include "paddle/cinn/adt/op_arg_pos.h" #include "paddle/cinn/adt/print.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/utils/type_defs.h" #include "glog/logging.h" @@ -55,7 +57,7 @@ std::vector MakeTensorRanks(const List& arg_lists) { return ret; } -void GenerateOpEquationsImpl(const hlir::framework::Node* op_node, +void GenerateOpEquationsImpl(const ::pir::Operation* op_node, const OpStmt& op_stmt, config::NaiveOpEquationContext* ctx) { const auto& [_, inputs, outputs] = op_stmt.tuple(); @@ -66,19 +68,22 @@ void GenerateOpEquationsImpl(const hlir::framework::Node* op_node, const auto& generate_equations = hlir::framework::Operator::GetAttrs( "generate_equations"); - CHECK(generate_equations.Find(op_node->op())); - generate_equations[op_node->op()](ctx); + const hlir::framework::Operator* cinn_op = hlir::framework::Operator::Get( + hlir::framework::pir::CompatibleInfo::OpName(*op_node)); + CHECK(generate_equations.Find(cinn_op)); + generate_equations[cinn_op](ctx); } -std::optional GetArgStaticDimSize( - const List& tensors, std::size_t tensor_idx, std::size_t dim_idx) { +std::optional GetArgStaticDimSize(const List& tensors, + std::size_t tensor_idx, + std::size_t dim_idx) { if (!tensors->at(tensor_idx).Has()) { return std::nullopt; } if (tensor_idx >= tensors->size()) { return std::nullopt; } - const auto& tensor_shape = + const std::vector tensor_shape = tensors->at(tensor_idx).Get().GetShape(); if (dim_idx >= tensor_shape.size()) { return std::nullopt; @@ -86,15 +91,16 @@ std::optional GetArgStaticDimSize( return tensor_shape.at(dim_idx); } -std::optional GetArgDimExpr( - const List& tensors, std::size_t tensor_idx, std::size_t dim_idx) { +std::optional GetArgDimExpr(const List& tensors, + std::size_t tensor_idx, + std::size_t dim_idx) { if (!tensors->at(tensor_idx).Has()) { return std::nullopt; } if (tensor_idx >= tensors->size()) { return std::nullopt; } - const auto& tensor_shape = + const std::vector tensor_shape = tensors->at(tensor_idx).Get().GetShape(); if (dim_idx >= tensor_shape.size()) { return std::nullopt; @@ -103,8 +109,9 @@ std::optional GetArgDimExpr( return tensor_shape.at(dim_idx); } -std::optional GetArgDim( - const List& tensors, std::size_t tensor_idx, std::size_t dim_idx) { +std::optional GetArgDim(const List& tensors, + std::size_t tensor_idx, + std::size_t dim_idx) { const auto& opt_expr = GetArgDimExpr(tensors, tensor_idx, dim_idx); if (opt_expr.has_value()) { return opt_expr; @@ -135,15 +142,14 @@ GetArgSymbolicDimT MakeGetterArgSymbolicDim(const List& tensors) { }; } -void GenerateOpEquationsImpl( - const tReduceAcc& op_node, - const OpStmt& op_stmt, - config::NaiveOpEquationContext* ctx) { +void GenerateOpEquationsImpl(const tReduceAcc& op_node, + const OpStmt& op_stmt, + config::NaiveOpEquationContext* ctx) { GenerateOpEquationsImpl(op_node.value(), op_stmt, ctx); } void GenerateOpEquationsImpl( - const tReduceInit& op_node, + const tReduceInit& op_node, const OpStmt& op_stmt, config::NaiveOpEquationContext* ctx) { // Do nothing @@ -160,29 +166,25 @@ void GenerateOpEquations(const OpStmt& op_stmt, op.variant()); } -const hlir::framework::AttrMapType* GetOpAttrImpl( - const hlir::framework::Node* op_node) { - return &op_node->attrs.attr_store; +cinn::utils::AttributeMap GetOpAttrImpl(const ::pir::Operation* op_node) { + return hlir::framework::pir::CompatibleInfo::ConvertAttributes(*op_node); } -const hlir::framework::AttrMapType* GetOpAttrImpl( - const tReduceInit&) { - static hlir::framework::AttrMapType empty{}; - return ∅ +cinn::utils::AttributeMap GetOpAttrImpl( + const tReduceInit&) { + return cinn::utils::AttributeMap{}; } -const hlir::framework::AttrMapType* GetOpAttrImpl( - const tReduceAcc& op_node) { +cinn::utils::AttributeMap GetOpAttrImpl( + const tReduceAcc& op_node) { return GetOpAttrImpl(op_node.value()); } -const hlir::framework::AttrMapType* GetOpAttr(const OpStmt& op_stmt) { +cinn::utils::AttributeMap GetOpAttr(const OpStmt& op_stmt) { const auto& [op_node, inputs, outputs] = op_stmt.tuple(); - const auto* attr = std::visit( - [&](const auto& impl) { return GetOpAttrImpl(impl); }, op_node.variant()); - - return attr; + return std::visit([&](const auto& impl) { return GetOpAttrImpl(impl); }, + op_node.variant()); } std::shared_ptr MakeContextAndGenerateEquations( diff --git a/paddle/cinn/adt/naive_op_equation_context.h b/paddle/cinn/adt/naive_op_equation_context.h index 7f2b2bc8fe95ee..ce8018889eb6a9 100644 --- a/paddle/cinn/adt/naive_op_equation_context.h +++ b/paddle/cinn/adt/naive_op_equation_context.h @@ -48,7 +48,7 @@ class NaiveOpEquationContext final : public OpEquationContext { GetArgStaticDimT GetOutDim, GetArgSymbolicDimT GetSymbolicInDim, GetArgSymbolicDimT GetSymbolicOutDim, - const hlir::framework::AttrMapType* attr_map_type) + cinn::utils::AttributeMap attr_map_type) : in_tensors_ranks_(in_tensors_ranks), out_tensors_ranks_(out_tensors_ranks), GetInDim_(GetInDim), @@ -227,8 +227,8 @@ class NaiveOpEquationContext final : public OpEquationContext { } std::optional GetSymbolicDimSize(bool is_out, - std::size_t arg_idx, - std::size_t axis) const { + std::size_t arg_idx, + std::size_t axis) const { const auto* Get = (is_out ? &GetSymbolicOutDim_ : &GetSymbolicInDim_); const auto& opt_dim = (*Get)(arg_idx, axis); return opt_dim; @@ -247,7 +247,8 @@ class NaiveOpEquationContext final : public OpEquationContext { } } - void InitInputDimExpr(std::vector* vec, const std::vector& tensors_ranks) { + void InitInputDimExpr(std::vector* vec, + const std::vector& tensors_ranks) { for (std::size_t i = 0; i < tensors_ranks.size(); ++i) { vec->push_back(DimTuple{}); for (std::size_t j = 0; j < tensors_ranks.at(i); ++j) { @@ -258,7 +259,8 @@ class NaiveOpEquationContext final : public OpEquationContext { } } - void InitOutputDimExpr(std::vector* vec, const std::vector& tensors_ranks) { + void InitOutputDimExpr(std::vector* vec, + const std::vector& tensors_ranks) { for (std::size_t i = 0; i < tensors_ranks.size(); ++i) { vec->push_back(DimTuple{}); for (std::size_t j = 0; j < tensors_ranks.at(i); ++j) { @@ -318,8 +320,8 @@ class NaiveOpEquationContext final : public OpEquationContext { } const utils::Attribute& GetAttribute(const std::string& name) const { - const auto& iter = attr_map_type_->find(name); - CHECK(iter != attr_map_type_->end()) + const auto& iter = attr_map_type_.find(name); + CHECK(iter != attr_map_type_.end()) << "Can't find Attribute with this name"; return iter->second; } @@ -331,7 +333,7 @@ class NaiveOpEquationContext final : public OpEquationContext { GetArgSymbolicDimT GetSymbolicInDim_; GetArgSymbolicDimT GetSymbolicOutDim_; Equations equations_; - const hlir::framework::AttrMapType* attr_map_type_; + const cinn::utils::AttributeMap attr_map_type_; FakeOpPlaceHolder fake_op_placeholder_; std::vector in_iterator_tuples_; diff --git a/paddle/cinn/adt/print_equations.cc b/paddle/cinn/adt/print_equations.cc index 859957b6d755d3..2e3c5ca72bf4bd 100644 --- a/paddle/cinn/adt/print_equations.cc +++ b/paddle/cinn/adt/print_equations.cc @@ -13,25 +13,29 @@ // limitations under the License. #include "paddle/cinn/adt/print_equations.h" -#include "paddle/cinn/adt/print_dim_expr.h" #include #include #include "paddle/cinn/adt/equation_function.h" +#include "paddle/cinn/adt/print_dim_expr.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/pir/core/operation.h" namespace cinn::adt { namespace { -std::string OpImpl(const hlir::framework::Node* op) { return op->op()->name; } +std::string OpImpl(const ::pir::Operation* op) { + return hlir::framework::pir::CompatibleInfo::OpName(*op); +} -std::string OpImpl(const tReduceInit& op) { - return op.value()->op()->name + "_init"; +std::string OpImpl(const tReduceInit& op) { + return OpImpl(op.value()) + "_init"; } -std::string OpImpl(const tReduceAcc& op) { - return op.value()->op()->name + "_acc"; +std::string OpImpl(const tReduceAcc& op) { + return OpImpl(op.value()) + "_acc"; } } // namespace @@ -185,8 +189,8 @@ struct ToTxtStringStruct { } std::string operator()( - const IndexUnDot, tOut>, tIn>& - undot) const { + const IndexUnDot, tOut>, tIn>& undot) + const { std::string ret; const auto& [dim_list, out_iterator_list_tag, in_index_tag] = undot.tuple(); const List& out_iterator_list = out_iterator_list_tag.value(); diff --git a/paddle/cinn/adt/print_map_expr.cc b/paddle/cinn/adt/print_map_expr.cc index 1e8d82478ad6e4..e25317faf8ba1d 100644 --- a/paddle/cinn/adt/print_map_expr.cc +++ b/paddle/cinn/adt/print_map_expr.cc @@ -21,6 +21,7 @@ #include "paddle/cinn/adt/print_schedule_mesh.h" #include "paddle/cinn/adt/print_value.h" #include "paddle/cinn/adt/schedule_descriptor.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/runtime/flags.h" PD_DECLARE_bool(cinn_map_expr_enable_index_detail); @@ -58,14 +59,14 @@ namespace { std::string ToTxtStringImpl(const adapter::Tensor& tensor) { std::string ret; ret += "t_"; - ret += tensor.node_data->id(); + ret += hlir::framework::pir::CompatibleInfo::ValueName(tensor.node_data); return ret; } std::string ToTxtStringImpl(const adapter::DynamicTensor& tensor) { std::string ret; ret += "t_"; - ret += tensor.node_data->id(); + ret += hlir::framework::pir::CompatibleInfo::ValueName(tensor.node_data); return ret; } @@ -109,18 +110,16 @@ std::string ArgsToTxtString(const List& out_args, return ArgsToTxtString(out_args, in_args, GetEmptyStr, GetEmptyStr); } -std::string ToTxtStringOpImpl(const hlir::framework::Node* op) { - return op->op()->name; +std::string ToTxtStringOpImpl(const ::pir::Operation* op) { + return hlir::framework::pir::CompatibleInfo::OpName(*op); } -std::string ToTxtStringOpImpl( - const tReduceInit& op) { - return op.value()->op()->name + "_init"; +std::string ToTxtStringOpImpl(const tReduceInit& op) { + return ToTxtStringOpImpl(op.value()) + "_init"; } -std::string ToTxtStringOpImpl( - const tReduceAcc& op) { - return op.value()->op()->name + "_acc"; +std::string ToTxtStringOpImpl(const tReduceAcc& op) { + return ToTxtStringOpImpl(op.value()) + "_acc"; } std::string ToTxtString(const Op& op) { diff --git a/paddle/cinn/adt/symbolic_dim_infer_ctx.h b/paddle/cinn/adt/symbolic_dim_infer_ctx.h index df704050390e2c..3b37132028c449 100644 --- a/paddle/cinn/adt/symbolic_dim_infer_ctx.h +++ b/paddle/cinn/adt/symbolic_dim_infer_ctx.h @@ -14,9 +14,9 @@ #pragma once -#include "paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h" #include "paddle/cinn/adt/dim_expr.h" -#include "paddle/cinn/hlir/framework/node.h" +#include "paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h" +#include "paddle/pir/core/operation.h" namespace cinn::adt::config { @@ -25,7 +25,7 @@ class SymbolicDimInferCtx { SymbolicDimInferCtx(const SymbolicDimInferCtx&) = delete; SymbolicDimInferCtx(SymbolicDimInferCtx&&) = delete; - SymbolicDimInferCtx(const hlir::framework::Node* node, + SymbolicDimInferCtx(const ::pir::Operation* node, GraphSymbolicDimInferCtx* graph_ctx) : node_(node), graph_ctx_(graph_ctx) {} @@ -38,7 +38,7 @@ class SymbolicDimInferCtx { } const DimExpr& GetInputDimExpr(std::size_t arg_idx, - std::size_t dim_idx) const { + std::size_t dim_idx) const { return graph_ctx_->GetInputDimExpr(node_, arg_idx, dim_idx); } @@ -58,7 +58,7 @@ class SymbolicDimInferCtx { } private: - const hlir::framework::Node* node_; + const ::pir::Operation* node_; GraphSymbolicDimInferCtx* graph_ctx_; }; diff --git a/paddle/cinn/adt/symbolic_dim_infer_util.cc b/paddle/cinn/adt/symbolic_dim_infer_util.cc index 6423fab2a2acf6..0f7e6969b0c8c0 100644 --- a/paddle/cinn/adt/symbolic_dim_infer_util.cc +++ b/paddle/cinn/adt/symbolic_dim_infer_util.cc @@ -15,37 +15,32 @@ #include "paddle/cinn/adt/symbolic_dim_infer_util.h" #include "paddle/cinn/adt/symbolic_dim_infer_ctx.h" -#include "paddle/cinn/common/graph_utils.h" -#include "paddle/cinn/hlir/framework/graph.h" -#include "paddle/cinn/hlir/framework/node.h" +#include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" namespace cinn::adt { +// ADT_TODO : Replace Group with AnalysisManager std::unique_ptr InferSymbolicDim( - const hlir::framework::Graph* graph) { + const cinn::hlir::framework::pir::Group* group) { using InferSymbolicDimFunc = std::function; auto infer_ctx_ptr = - std::make_unique(graph); - - std::vector topo_nodes = - std::get<0>(graph->topological_order()); - for (const common::GraphNode* graph_node : topo_nodes) { - const hlir::framework::Node* op_node = - graph_node->safe_as(); - // if node is NodeData or not op, continue. - if (!op_node || op_node->op() == nullptr) { - continue; - } - - VLOG(1) << "op_name : " << op_node->op()->name; + std::make_unique(group); + + for (const ::pir::Operation* op_node : group->ops) { + VLOG(1) << "op_name : " + << hlir::framework::pir::CompatibleInfo::OpName(*op_node); const auto& infer_symbolic_dim = hlir::framework::Operator::GetAttrs( "infer_symbolic_dim"); - CHECK(infer_symbolic_dim.Find(op_node->op())); + + const hlir::framework::Operator* cinn_op = hlir::framework::Operator::Get( + hlir::framework::pir::CompatibleInfo::OpName(*op_node)); + CHECK(infer_symbolic_dim.Find(cinn_op)); adt::config::SymbolicDimInferCtx ctx{op_node, infer_ctx_ptr.get()}; - infer_symbolic_dim[op_node->op()](&ctx); + infer_symbolic_dim[cinn_op](&ctx); } return infer_ctx_ptr; } diff --git a/paddle/cinn/adt/symbolic_dim_infer_util.h b/paddle/cinn/adt/symbolic_dim_infer_util.h index 30593b2fa06b35..141366358b5b35 100644 --- a/paddle/cinn/adt/symbolic_dim_infer_util.h +++ b/paddle/cinn/adt/symbolic_dim_infer_util.h @@ -16,8 +16,8 @@ #include -namespace cinn::hlir::framework { -class Graph; +namespace cinn::hlir::framework::pir { +struct Group; } namespace cinn::adt { @@ -27,6 +27,6 @@ class GraphSymbolicDimInferCtx; } std::unique_ptr InferSymbolicDim( - const hlir::framework::Graph* graph); + const hlir::framework::pir::Group* group); } // namespace cinn::adt diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc index ba5c946ff31643..9c959c3878e092 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc @@ -18,6 +18,7 @@ #include +#include "paddle/cinn/adt/generate_map_expr.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" @@ -25,9 +26,12 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/cinn/hlir/framework/pir_compiler.h" +#include "paddle/cinn/runtime/flags.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" +PD_DECLARE_bool(cinn_enable_map_expr); + namespace cinn { namespace dialect { namespace ir { @@ -147,8 +151,12 @@ std::unique_ptr CINNGroupLoweringPass(::pir::Program* program) { auto ir_compiler = std::make_shared( *program, target, scope); hlir::framework::PirCompilerManager::Instance().insert(ir_compiler); + adt::TryGenerateMapExprFromGroup(group); auto group1 = std::make_shared(group->ops); + if (FLAGS_cinn_enable_map_expr) { + group1->set_map_expr_ctx(group->mut_map_expr_ctx()); + } auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group1}); std::unordered_map op_attrs{ {cinn::dialect::JitKernelOp::kAttrName, diff --git a/paddle/cinn/hlir/framework/graph.h b/paddle/cinn/hlir/framework/graph.h index 7b360b1ddc2b04..e1d63d9f06f5a8 100644 --- a/paddle/cinn/hlir/framework/graph.h +++ b/paddle/cinn/hlir/framework/graph.h @@ -28,15 +28,6 @@ namespace cinn { -namespace adt { -class MapExprCtx; - -namespace config { -class GraphSymbolicDimInferCtx; -} - -} // namespace adt - namespace hlir { namespace framework { @@ -67,20 +58,6 @@ class Graph : public cinn::common::Graph { /** \brief attributes of a graph */ absl::flat_hash_map> attrs; - void set_graph_ctx( - std::unique_ptr&& graph_ctx) { - CHECK_EQ(this, graph_ctx->graph()); - graph_ctx_ = std::move(graph_ctx); - } - - const adt::config::GraphSymbolicDimInferCtx* graph_ctx() const { - return graph_ctx_.get(); - } - - adt::config::GraphSymbolicDimInferCtx* mut_graph_ctx() { - return graph_ctx_.get(); - } - std::vector> groups; struct Group { Group() = default; @@ -207,17 +184,6 @@ class Graph : public cinn::common::Graph { hlir::framework::OpPatternKind kind() const { return op_pattern_kind; } - adt::MapExprCtx* mut_map_expr_ctx() { return map_expr_ctx_.get(); } - - const adt::MapExprCtx& map_expr_ctx() const { - return *CHECK_NOTNULL(map_expr_ctx_); - } - - void set_map_expr_ctx( - const std::shared_ptr& map_expr_ctx) { - map_expr_ctx_ = map_expr_ctx; - } - private: // input groups std::unordered_set, @@ -229,7 +195,6 @@ class Graph : public cinn::common::Graph { SharedGroupHasher, SharedGroupComparator> consumer_groups_; - std::shared_ptr map_expr_ctx_; }; std::vector> fusion_groups; @@ -337,8 +302,6 @@ class Graph : public cinn::common::Graph { std::vector> FusionGroupsToGroups(); - std::unique_ptr graph_ctx_; - CINN_DISALLOW_COPY_AND_ASSIGN(Graph); }; diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.cc b/paddle/cinn/hlir/framework/op_lowering_impl.cc index d321b94f212a16..e041633b0748bf 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/op_lowering_impl.cc @@ -14,21 +14,17 @@ #include "paddle/cinn/hlir/framework/op_lowering_impl.h" -#include "paddle/cinn/adt/map_expr_ctx.h" #include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/hlir/framework/compile_error.h" #include "paddle/cinn/hlir/framework/graph_compiler_util.h" #include "paddle/cinn/hlir/framework/op_lowering_util.h" #include "paddle/cinn/hlir/op/external_api_registry.h" -#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" #include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" #include "paddle/cinn/runtime/flags.h" PD_DECLARE_bool(cinn_use_cuda_vectorize); -PD_DECLARE_bool(cinn_enable_map_expr); -PD_DECLARE_bool(cinn_map_expr_enable_schedule); PD_DECLARE_bool(cinn_new_group_scheduler); namespace cinn { @@ -103,49 +99,6 @@ bool OpLowererImpl::NonFusibleScheduleDetermineFunction(Node* node) { return true; } -/* Most of below codes copies from `PostProcess` function */ -std::vector OpLowererImpl::LowerMapExpr( - const GroupPtr& group, - const std::unordered_map& tensor_map, - bool do_op_schedule, - bool apply_group_schedule, - bool apply_pass, - std::vector* group_func_arg_tensors) { - if (!FLAGS_cinn_map_expr_enable_schedule) { - do_op_schedule = false; - apply_group_schedule = false; - apply_pass = true; - } - VLOG(1) << "FLAGS_cinn_map_expr_enable_schedule = " - << FLAGS_cinn_map_expr_enable_schedule; - VLOG(1) << "do_op_schedule = " << do_op_schedule; - VLOG(1) << "apply_group_schedule = " << apply_group_schedule; - VLOG(1) << "apply_pass = " << apply_pass; - - ir::Expr func_body = adt::MapExprToIr(group->map_expr_ctx(), target_); - - // 2.Do group schedule. - ir::ModuleExpr mod_expr({func_body}); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - if (apply_group_schedule) { - DoGroupSchedule(ir_sch, group, tensor_map); - VLOG(3) << "After group schedule, ir is: \n" - << ir_sch.GetModule().GetExprs().at(0); - } - - // 3.Do post-processing, - // including preparing function args and temporary variables, - // applying low-level optimization passes, etc. - return PostProcess(group, - tensor_map, - do_op_schedule, - apply_pass, - &ir_sch, - group_func_arg_tensors); -} - std::vector OpLowererImpl::LowerGroup( const GroupPtr& group, bool apply_op_schedule, @@ -169,15 +122,6 @@ std::vector OpLowererImpl::LowerGroup( &group_func_arg_tensors, &tensor_map); - if (FLAGS_cinn_enable_map_expr) { - return LowerMapExpr(group, - tensor_map, - /*do_op_schedule=*/do_op_schedule, - /*apply_group_schedule=*/apply_group_schedule, - /*apply_pass=*/apply_pass, - &group_func_arg_tensors); - } - // 2.Do group schedule. ir::ModuleExpr mod_expr(func_bodies); ir::IRSchedule ir_sch(mod_expr); @@ -480,9 +424,6 @@ std::vector OpLowererImpl::DoOpLower( for (auto fun : funcs) { VLOG(4) << fun; } - if (FLAGS_cinn_enable_map_expr) { - group->mut_map_expr_ctx()->UpdateOpLoweredFuncKey(node, funcs); - } op_func_arg_tensors->clear(); for (int idx = 0; idx < pack.size() - 1; ++idx) { diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.h b/paddle/cinn/hlir/framework/op_lowering_impl.h index 3c458b9ecc11eb..0ff9cce38bf51c 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl.h @@ -102,24 +102,6 @@ class OpLowererImpl : public OpLowererImplBase { ir::IRSchedule* ir_sch, std::vector* group_func_arg_tensors); - /** - * @brief Generate MapExpr and Lower it to std::vector - * @param group The group to be lowered. - * @param tensor_map All tensors used for calculating the group. - * @param done_op_schedule Mark whether the Op level schedule has been - * applied. - * @param apply_group_schedule Whether to schedule at group level. - * @param group_func_arg_tensors Tensors used as the group function arguments. - * @return The lowered funcs after the post processing. - */ - std::vector LowerMapExpr( - const GroupPtr& group, - const std::unordered_map& tensor_map, - bool done_op_schedule, - bool apply_group_schedule, - bool apply_pass, - std::vector* group_func_arg_tensors); - /** * @brief Lower an Op set to CINN IR. * Compute, Lower and optional Schedule will be performed one by one diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 7b0913525c2545..2e7e05c869be36 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -21,6 +21,16 @@ #include "paddle/pir/core/operation.h" namespace cinn { + +namespace adt { + +namespace config { +class GraphSymbolicDimInferCtx; +} + +class MapExprCtx; +} // namespace adt + namespace hlir { namespace framework { namespace pir { @@ -30,6 +40,8 @@ using framework::OpPatternKind; struct Group { public: Group() = default; + Group(const Group&) = delete; + Group(Group&&) = delete; explicit Group(const std::vector<::pir::Operation*>& group_ops) : ops(group_ops) {} @@ -81,7 +93,7 @@ struct Group { } }; - std::vector<::pir::Operation*> CollectOps() { + std::vector<::pir::Operation*> CollectOps() const { if (fused_sub_groups.size()) { std::vector<::pir::Operation*> tmp_ops; for (auto& group : fused_sub_groups) { @@ -107,7 +119,7 @@ struct Group { } } - std::unordered_set<::pir::Operation*> OpSet() { + std::unordered_set<::pir::Operation*> OpSet() const { std::unordered_set<::pir::Operation*> op_set; for (auto op : CollectOps()) { op_set.insert(op); @@ -115,7 +127,7 @@ struct Group { return op_set; } - std::unordered_set<::pir::Value> GetInputOpValues() { + std::unordered_set<::pir::Value> GetInputOpValues() const { std::unordered_set<::pir::Value> group_inputs; auto ops_set = this->OpSet(); // count all op's input Value @@ -144,7 +156,7 @@ struct Group { return group_inputs; } - std::unordered_set<::pir::Value> GetOutputOpValues() { + std::unordered_set<::pir::Value> GetOutputOpValues() const { std::unordered_set<::pir::Value> group_outputs; for (auto op : this->output_ops) { @@ -161,6 +173,35 @@ struct Group { std::string GetFuncName() { return "fn_" + group_id + unique_id; } + std::shared_ptr mut_map_expr_ctx() { + CHECK_NOTNULL(map_expr_ctx_); + return map_expr_ctx_; + } + + const adt::MapExprCtx& map_expr_ctx() const { + return *CHECK_NOTNULL(map_expr_ctx_); + } + + void set_map_expr_ctx(const std::shared_ptr& map_expr_ctx) { + map_expr_ctx_ = map_expr_ctx; + } + + void set_graph_symbolic_dim_infer_ctx( + std::unique_ptr&& + graph_symbolic_dim_infer_ctx) { + CHECK_EQ(this, graph_symbolic_dim_infer_ctx->graph()); + graph_symbolic_dim_infer_ctx_ = std::move(graph_symbolic_dim_infer_ctx); + } + + const adt::config::GraphSymbolicDimInferCtx* graph_symbolic_dim_infer_ctx() + const { + return graph_symbolic_dim_infer_ctx_.get(); + } + + adt::config::GraphSymbolicDimInferCtx* mut_graph_symbolic_dim_infer_ctx() { + return graph_symbolic_dim_infer_ctx_.get(); + } + public: const std::unordered_set, SharedGroupHasher, @@ -211,6 +252,9 @@ struct Group { SharedGroupHasher, SharedGroupComparator> consumer_groups_; + std::shared_ptr map_expr_ctx_; + std::unique_ptr + graph_symbolic_dim_infer_ctx_; }; } // namespace pir diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 40cd6444a4fed9..fd7baa18a557f8 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -16,19 +16,23 @@ #include +#include "paddle/cinn/adt/map_expr_ctx.h" #include "paddle/cinn/ast_gen_ius/tensor_group.h" +#include "paddle/cinn/hlir/framework/compile_error.h" #include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" -#include "paddle/cinn/optim/transform_gpu_forloop.h" - -#include "paddle/cinn/hlir/framework/compile_error.h" -#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/phi/core/ddim.h" PD_DECLARE_bool(cinn_use_cuda_vectorize); +PD_DECLARE_bool(cinn_enable_map_expr); +PD_DECLARE_bool(cinn_enable_map_expr_schedule); namespace cinn { namespace hlir { @@ -182,11 +186,58 @@ bool OpLowererImpl::NonFusibleScheduleDetermineFunction(::pir::Operation* op) { return true; } +/* Most of below codes copies from `PostProcess` function */ +std::vector OpLowererImpl::LowerMapExpr( + const GroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info, + bool do_op_schedule, + bool apply_group_schedule, + std::vector* group_func_arg_tensors) { + VLOG(1) << "do_op_schedule = " << do_op_schedule; + VLOG(1) << "apply_group_schedule = " << apply_group_schedule; + + VLOG(1) << "Begin MapExprToIr"; + ir::Expr func_body = adt::MapExprToIr(group->map_expr_ctx(), target_); + + // 2.Do group schedule. + ir::ModuleExpr mod_expr({func_body}); + ir::IRSchedule ir_sch(mod_expr); + ir_sch.MergeExprs(); + VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + if (apply_group_schedule) { + std::unordered_set output_tensor_names; + std::transform( + group->output_ops.begin(), + group->output_ops.end(), + std::inserter(output_tensor_names, output_tensor_names.begin()), + [](::pir::Operation* node) { + ::pir::Value node_data = node->result(0); + return hlir::framework::pir::CompatibleInfo::ValueName(node_data); + }); + ir::StaticShapeGroupScheduler group_scheduler( + &ir_sch, output_tensor_names, target_); + group_scheduler.MapExprSchedule(); + VLOG(3) << "After group schedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + } + + // 3.Do post-processing, + // including preparing function args and temporary variables, + // applying low-level optimization passes, etc. + return PostProcess( + group, tensor_map, do_op_schedule, &ir_sch, group_func_arg_tensors); +} + std::vector OpLowererImpl::LowerGroup( const GroupPtr& group, bool apply_op_schedule, bool apply_group_schedule, ScheduleDetermineFunction schedule_determine_func) { + if (FLAGS_cinn_enable_map_expr && FLAGS_cinn_enable_map_expr_schedule) { + apply_op_schedule = false; + apply_group_schedule = false; + } // 1.Do compute, lower and schedule for each op. auto& ops = group->ops; if (ops.size() == 1 && ops[0]->name() == "custom_call") { @@ -198,12 +249,21 @@ std::vector OpLowererImpl::LowerGroup( // XX_0, XX_1, so we log them in tmp_tensor_info; std::unordered_map tmp_tensor_info; bool do_op_schedule = apply_group_schedule || apply_op_schedule; - std::vector func_bodies = LowerOps(ops, + std::vector func_bodies = LowerOps(group, + ops, do_op_schedule, schedule_determine_func, &group_func_arg_tensors, &tensor_map, &tmp_tensor_info); + if (FLAGS_cinn_enable_map_expr) { + return LowerMapExpr(group, + tensor_map, + tmp_tensor_info, + /*do_op_schedule=*/do_op_schedule, + /*apply_group_schedule=*/apply_group_schedule, + &group_func_arg_tensors); + } // 2.Do group schedule. ir::ModuleExpr mod_expr(func_bodies); @@ -348,6 +408,7 @@ std::vector OpLowererImpl::PostProcess( } std::vector OpLowererImpl::LowerOps( + const GroupPtr& group, const std::vector<::pir::Operation*>& ops, bool apply_op_schedule, ScheduleDetermineFunction schedule_determine_func, @@ -374,7 +435,7 @@ std::vector OpLowererImpl::LowerOps( node_attrs, op_func_arg_tensors, out_types, out_shapes, this->target_)); // 2.Perform the lower process of Op std::vector funcs = DoOpLower( - op_impl, op, tensor_map, tmp_tensor_info, &op_func_arg_tensors); + group, op_impl, op, tensor_map, tmp_tensor_info, &op_func_arg_tensors); if (apply_op_schedule && (this->*schedule_determine_func)(op)) { // 3.Perform the schedule of Op @@ -393,6 +454,7 @@ std::vector OpLowererImpl::LowerOps( } std::vector OpLowererImpl::DoOpLower( + const GroupPtr& group, std::shared_ptr op_impl, ::pir::Operation* op, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, @@ -460,6 +522,9 @@ std::vector OpLowererImpl::DoOpLower( VLOG(4) << fun; } } + if (FLAGS_cinn_enable_map_expr) { + group->mut_map_expr_ctx()->UpdateOpLoweredFuncKey(op, funcs); + } op_func_arg_tensors->clear(); for (int idx = 0; idx < pack.size() - 1; ++idx) { diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h index 156e7a399ced51..dd31233591b3d5 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h @@ -102,10 +102,29 @@ class OpLowererImpl : public OpLowererImplBase { ir::IRSchedule* ir_sch, std::vector* group_func_arg_tensors); + /** + * @brief Generate MapExpr and Lower it to std::vector + * @param group The group to be lowered. + * @param tensor_map All tensors used for calculating the group. + * @param done_op_schedule Mark whether the Op level schedule has been + * applied. + * @param apply_group_schedule Whether to schedule at group level. + * @param group_func_arg_tensors Tensors used as the group function arguments. + * @return The lowered funcs after the post processing. + */ + std::vector LowerMapExpr( + const GroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, + const std::unordered_map& tmp_tensor_info, + bool done_op_schedule, + bool apply_group_schedule, + std::vector* group_func_arg_tensors); + /** * @brief Lower an Op set to CINN IR. * Compute, Lower and optional Schedule will be performed one by one * for each Op. + * @param group The group to be lowered. * @param ops The Op to be lowered. * @param apply_op_schedule Whether to schedule at Op level. * @param schedule_determine_func Function used to determine which Ops to @@ -115,6 +134,7 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered func bodies of Op set. */ std::vector LowerOps( + const GroupPtr& group, const std::vector<::pir::Operation*>& ops, bool apply_op_schedule, ScheduleDetermineFunction schedule_determine_func, @@ -125,6 +145,7 @@ class OpLowererImpl : public OpLowererImplBase { /** * @brief Lower an Op to CINN IR. The Compute and Lower processes will be * called sequentially. + * @param group The group to be lowered. * @param op_impl The Op implementation defining Compute and Schedule. * @param op The Op to be lowered. * @param tensor_map All tensors used for calculating the group. @@ -132,6 +153,7 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered func of the Op. */ std::vector DoOpLower( + const GroupPtr& group, std::shared_ptr op_impl, ::pir::Operation* op, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map, diff --git a/paddle/cinn/hlir/pe/map_expr_to_ir.cc b/paddle/cinn/hlir/pe/map_expr_to_ir.cc index 66921b9d9ee9da..5a5e7bccff0c28 100644 --- a/paddle/cinn/hlir/pe/map_expr_to_ir.cc +++ b/paddle/cinn/hlir/pe/map_expr_to_ir.cc @@ -22,11 +22,15 @@ #include "paddle/cinn/adt/map_expr.h" #include "paddle/cinn/adt/map_expr_ctx.h" #include "paddle/cinn/adt/match.h" +#include "paddle/cinn/adt/no_inline_translator.h" #include "paddle/cinn/adt/print.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/runtime/flags.h" + +PD_DECLARE_bool(cinn_enable_map_expr_inline); namespace cinn::adt { @@ -34,7 +38,7 @@ namespace { using IteratorInt = std::int32_t; using Node2LoweredFuncs = - std::unordered_map>; + std::unordered_map<::pir::Operation*, std::vector>; using TensorIteratorExpr4TensorT = std::function(const adt::Tensor&)>; @@ -65,9 +69,9 @@ class MapExprToIrTranslator { } private: - ir::Expr GetStoreExprForOp(const hlir::framework::Node* op) const { + ir::Expr GetStoreExprForOp(const ::pir::Operation* op) const { const auto& iter = - node2lowered_funcs_->find(const_cast(op)); + node2lowered_funcs_->find(const_cast<::pir::Operation*>(op)); CHECK(iter != node2lowered_funcs_->end()); const auto& lowered_funcs = iter->second; CHECK_EQ(lowered_funcs.size(), 1); @@ -81,9 +85,9 @@ class MapExprToIrTranslator { } ir::Expr GetStoreExprForOp( - tReduceInit op) const { - const auto& iter = node2lowered_funcs_->find( - const_cast(op.value())); + const tReduceInit& op) const { + const auto& iter = + node2lowered_funcs_->find(const_cast<::pir::Operation*>(op.value())); CHECK(iter != node2lowered_funcs_->end()); const auto& lowered_funcs = iter->second; CHECK_EQ(lowered_funcs.size(), 1); @@ -96,9 +100,9 @@ class MapExprToIrTranslator { } ir::Expr GetStoreExprForOp( - tReduceAcc op) const { - const auto& iter = node2lowered_funcs_->find( - const_cast(op.value())); + const tReduceAcc& op) const { + const auto& iter = + node2lowered_funcs_->find(const_cast<::pir::Operation*>(op.value())); CHECK(iter != node2lowered_funcs_->end()); const auto& lowered_funcs = iter->second; CHECK_EQ(lowered_funcs.size(), 1); @@ -209,7 +213,12 @@ class MapExprToIrTranslator { } InlineStmt ConvertToInlineStmt(const InternalStmt& internal_stmt) const { - return InlineTranslator::Call(internal_stmt); + if (FLAGS_cinn_enable_map_expr_inline) { + return InlineTranslator::Call(internal_stmt); + } else { + return NoInlineTranslator::Call(internal_stmt); + } + LOG(FATAL) << "Dead code"; } std::optional TranslateOpExprImpl( @@ -242,8 +251,61 @@ class MapExprToIrTranslator { op_expr.variant()); } + std::optional MakeLoadExpr( + const ir::Expr& input_expr, + const List& op_expr_children, + const IterExprs4TensorT& IterExprs4Tensor) const { + ir::Expr store_rvalue = ir::ir_utils::IRCopy(input_expr); + CHECK_EQ(store_rvalue->operands.size(), 0); + CHECK_EQ(op_expr_children->size(), 1); + store_rvalue.As()->indices = + TranslateTensorIndex(op_expr_children->at(0), IterExprs4Tensor); + return store_rvalue; + } + + std::optional MakeCallExpr( + const ir::Expr& input_expr, + const List& op_expr_children, + const IterExprs4TensorT& IterExprs4Tensor) const { + ir::Expr store_rvalue = ir::ir_utils::IRCopy(input_expr); + CHECK_EQ(store_rvalue->operands.size(), 0); + CHECK(!op_expr_children->empty()); + CHECK_EQ((store_rvalue.As()->read_args.size()), + (op_expr_children->size())); + for (int i = 0; i < op_expr_children->size(); ++i) { + const auto& opt_operant = TranslateOpExpr( + op_expr_children->at(i), std::nullopt, IterExprs4Tensor); + if (opt_operant.has_value()) { + store_rvalue.As()->read_args.at(i) = opt_operant.value(); + } else { + store_rvalue.As()->read_args.at(i).As()->indices = + TranslateTensorIndex(op_expr_children->at(i), IterExprs4Tensor); + } + } + return store_rvalue; + } + + std::optional MakeGeneralExpr( + const ir::Expr& input_expr, + const List& op_expr_children, + const IterExprs4TensorT& IterExprs4Tensor) const { + ir::Expr store_rvalue = ir::ir_utils::IRCopy(input_expr); + CHECK_EQ(store_rvalue->operands.size(), op_expr_children->size()); + for (int i = 0; i < op_expr_children->size(); ++i) { + const auto& opt_operant = TranslateOpExpr( + op_expr_children->at(i), std::nullopt, IterExprs4Tensor); + if (opt_operant.has_value()) { + store_rvalue->operands.at(i) = opt_operant.value(); + } else { + store_rvalue->operands.at(i).As()->indices = + TranslateTensorIndex(op_expr_children->at(i), IterExprs4Tensor); + } + } + return store_rvalue; + } + std::optional TranslateOpCallImpl( - const hlir::framework::Node* op, + const ::pir::Operation* op, const OpCall& op_expr, const std::optional& opt_output_tensor, const IterExprs4TensorT& IterExprs4Tensor) const { @@ -253,23 +315,13 @@ class MapExprToIrTranslator { ir::Expr store_rvalue = store_expr.value().As()->value; if (store_rvalue.As()) { - CHECK_EQ(store_rvalue->operands.size(), 0); - CHECK_EQ(op_expr_children->size(), 1); - store_rvalue.As()->indices = - TranslateTensorIndex(op_expr_children->at(0), IterExprs4Tensor); + return MakeLoadExpr(store_rvalue, op_expr_children, IterExprs4Tensor); + } else if (store_rvalue.As()) { + return MakeCallExpr(store_rvalue, op_expr_children, IterExprs4Tensor); } else { if (!op_expr_children->empty()) { - CHECK_EQ(store_rvalue->operands.size(), op_expr_children->size()); - for (int i = 0; i < op_expr_children->size(); ++i) { - const auto& opt_operant = TranslateOpExpr( - op_expr_children->at(i), std::nullopt, IterExprs4Tensor); - if (opt_operant.has_value()) { - store_rvalue->operands.at(i) = opt_operant.value(); - } else { - store_rvalue->operands.at(i).As()->indices = - TranslateTensorIndex(op_expr_children->at(i), IterExprs4Tensor); - } - } + return MakeGeneralExpr( + store_rvalue, op_expr_children, IterExprs4Tensor); } else { // Do nothing } @@ -278,7 +330,7 @@ class MapExprToIrTranslator { } std::optional TranslateOpCallImpl( - const tReduceInit& op, + const tReduceInit& op, const OpCall& op_expr, const std::optional& opt_output_tensor, const IterExprs4TensorT& IterExprs4Tensor) const { @@ -294,7 +346,7 @@ class MapExprToIrTranslator { } std::optional TranslateOpCallImpl( - const tReduceAcc& op, + const tReduceAcc& op, const OpCall& op_expr, const std::optional& opt_output_tensor, const IterExprs4TensorT& IterExprs4Tensor) const { diff --git a/paddle/cinn/hlir/pe/map_expr_to_ir.h b/paddle/cinn/hlir/pe/map_expr_to_ir.h index c89faac13d294f..18d954bb34d3d2 100644 --- a/paddle/cinn/hlir/pe/map_expr_to_ir.h +++ b/paddle/cinn/hlir/pe/map_expr_to_ir.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/cinn/adt/map_expr.h" +#include "paddle/cinn/adt/map_expr_ctx.h" #include "paddle/cinn/ir/ir.h" namespace cinn::common { diff --git a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc index 8c2ae6a6799c9e..21ef03bd6d5b36 100644 --- a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc @@ -143,6 +143,13 @@ void StaticShapeGroupScheduler::Schedule() { #endif } +void StaticShapeGroupScheduler::MapExprSchedule() { + DoComputeInline(); +#ifdef CINN_WITH_CUDA + AllocateStorage(); +#endif +} + std::vector> StaticShapeGroupScheduler::GetIRs() { return {{Expr(1), ir_sch_->GetModule().GetExprs()[0]}}; diff --git a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h index b2b89c392bdc05..81d71a853dbfd5 100644 --- a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h +++ b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h @@ -51,6 +51,8 @@ class StaticShapeGroupScheduler : public GroupScheduler { void Schedule() override; + void MapExprSchedule(); + std::vector> GetIRs() override; private: diff --git a/paddle/cinn/pybind/frontend.cc b/paddle/cinn/pybind/frontend.cc index 3fa0abc78f8fd2..aafa9bedf40d07 100644 --- a/paddle/cinn/pybind/frontend.cc +++ b/paddle/cinn/pybind/frontend.cc @@ -203,8 +203,6 @@ void BindFrontend(pybind11::module *m) { auto graph = Optimize(&self, fetch_ids, target, passes); - cinn::adt::TryGenerateMapExprFromGraph(graph); - scope = hlir::framework::BuildScope(target, graph, scope); hlir::framework::CompilationContext context(graph, scope, target); diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 3e06ede9b31530..16fabc434b5a10 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -79,10 +79,12 @@ PD_DEFINE_bool(cinn_enable_map_expr, BoolFromEnv("FLAGS_cinn_enable_map_expr", false), "It controls whether to use cinn with map_expr"); -PD_DEFINE_bool( - cinn_map_expr_enable_schedule, - BoolFromEnv("FLAGS_cinn_map_expr_enable_schedule", false), - "It controls whether to use schedule and pass when enables map_expr"); +PD_DEFINE_bool(cinn_enable_map_expr_schedule, + BoolFromEnv("FLAGS_cinn_enable_map_expr_schedule", true), + "It controls whether to schedule by map_expr"); +PD_DEFINE_bool(cinn_enable_map_expr_inline, + BoolFromEnv("FLAGS_cinn_enable_map_expr_inline", false), + "It controls whether to inline by map_expr"); PD_DEFINE_bool( cinn_map_expr_enable_dynamic_shape, diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 34e43ac5d71ee6..47bc076ca689aa 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -321,8 +321,6 @@ std::unique_ptr CinnCompiler::CompileGraph( << target.arch_str() << "), and its related graph:\n" << cinn_graph->Visualize(); - cinn::adt::TryGenerateMapExprFromGraph(cinn_graph); - auto scope = BuildScope(target, cinn_graph); CompilationContext context(cinn_graph, scope, target); context.with_instantiate_variables = false; diff --git a/test/cinn/CMakeLists.txt b/test/cinn/CMakeLists.txt index 3a25d7e1ef182a..3158c4372d8fdb 100644 --- a/test/cinn/CMakeLists.txt +++ b/test/cinn/CMakeLists.txt @@ -310,27 +310,4 @@ if(WITH_GPU) WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) endforeach() - # adt test - file( - GLOB CINN_ADT_TEST - RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" - "adt/test_*.py") - set(EXCLUDE_ADT_TEST test_add_inline) - - foreach(adt_test_name ${EXCLUDE_ADT_TEST}) - list(REMOVE_ITEM CINN_ADT_TEST adt/${adt_test_name}.py) - endforeach() - - foreach(adt_test_name ${CINN_ADT_TEST}) - string(REGEX REPLACE ".py" "" adt_test_name ${adt_test_name}) - add_test( - NAME ${adt_test_name} - COMMAND - ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH} - FLAGS_cinn_enable_map_expr=True python3 - ${CMAKE_CURRENT_SOURCE_DIR}/${adt_test_name}.py - WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) - endforeach() - endif() diff --git a/test/cinn/adt/test_add_inline.py b/test/cinn/adt/test_add_inline.py deleted file mode 100755 index df92e86677abcc..00000000000000 --- a/test/cinn/adt/test_add_inline.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from cinn.common import DefaultNVGPUTarget, Float -from cinn.frontend import NetBuilder - - -class TestMapExprAddFusion(unittest.TestCase): - def setUp(self): - self.inputs = { - "x": np.random.uniform(-1.0, 1.0, [1024, 1024]).astype("float32"), - "y": np.random.uniform(-1.0, 1.0, [1024, 1024]).astype("float32"), - "z": np.random.uniform(-1.0, 1.0, [1024, 1024]).astype("float32"), - } - - def test_add_fusion(self): - builder = NetBuilder("TestMapExprAddFusion") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") - y = builder.create_input(Float(32), self.inputs["y"].shape, "y") - z = builder.create_input(Float(32), self.inputs["z"].shape, "z") - - a = builder.elementwise_add(x, y) - out = builder.elementwise_add(a, z) - - prog = builder.build() - target = DefaultNVGPUTarget() - result = prog.build_and_get_output( - target, - [x, y, z], - [self.inputs["x"], self.inputs["y"], self.inputs["z"]], - [out], - passes=[], - scope=None, - ) - - np.testing.assert_allclose( - result[0].numpy(target), - self.inputs["x"] + self.inputs["y"] + self.inputs["z"], - err_msg="TestMapExprAddFusion failed!", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/cinn/adt/test_broadcast_expr.py b/test/cinn/adt/test_broadcast_expr.py deleted file mode 100755 index a9caea37c6d064..00000000000000 --- a/test/cinn/adt/test_broadcast_expr.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from cinn.common import DefaultNVGPUTarget, Float -from cinn.frontend import NetBuilder - - -class TestMapExprBroadcast(unittest.TestCase): - def setUp(self): - self.inputs = { - "x1": np.random.uniform(-1.0, 1.0, [4, 16]).astype("float32"), - "x2": np.random.uniform(-1.0, 1.0, [16]).astype("float32"), - } - - def test_broadcast(self): - builder = NetBuilder("TestMapExprBroadcast") - x1 = builder.create_input(Float(32), self.inputs["x1"].shape, "x1") - x2 = builder.create_input(Float(32), self.inputs["x2"].shape, "x2") - z = builder.elementwise_add(x1, x2) - out = builder.relu(z) - prog = builder.build() - - target = DefaultNVGPUTarget() - - result = prog.build_and_get_output( - target, - [x1, x2], - [self.inputs["x1"], self.inputs["x2"]], - [out], - passes=[], - scope=None, - ) - - np.testing.assert_allclose( - result[0].numpy(target), - np.maximum((self.inputs["x1"] + self.inputs["x2"]), 0), - err_msg="TestMapExprBroadcast failed!", - ) - print("Finish Test") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/cinn/adt/test_fusion_ability.py b/test/cinn/adt/test_fusion_ability.py deleted file mode 100644 index 0bf40af0fc34ff..00000000000000 --- a/test/cinn/adt/test_fusion_ability.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from cinn.common import DefaultNVGPUTarget, Float -from cinn.frontend import NetBuilder - - -class TestMapExprReduceFusion(unittest.TestCase): - def setUp(self): - self.inputs = { - "x": np.random.uniform(-1.0, 1.0, [2, 1024]).astype("float32"), - "y": np.random.uniform(-1.0, 1.0, [2, 1024]).astype("float32"), - } - - def test_reduce_fusion(self): - builder = NetBuilder("TestMapExprReduceFusion") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") - y = builder.create_input(Float(32), self.inputs["y"].shape, "y") - - t = builder.elementwise_add(x, y) - t = builder.relu(t) - t = builder.elementwise_add(t, y) - out = builder.reduce_sum(t, [0], False) - prog = builder.build() - - target = DefaultNVGPUTarget() - result = prog.build_and_get_output( - target, - [x, y], - [self.inputs["x"], self.inputs["y"]], - [out], - passes=[], - scope=None, - ) - - np_expect = self.inputs["x"] + self.inputs["y"] - np_expect = np.maximum(np_expect, 0) - np_expect = np_expect + self.inputs["y"] - np_expect = np.sum(np_expect, axis=0) - - np.testing.assert_allclose( - result[0].numpy(target), - np_expect, - err_msg="TestMapExprReduceFusion failed!", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/cinn/adt/test_naive_add.py b/test/cinn/adt/test_naive_add.py deleted file mode 100755 index 304978ae5ceab5..00000000000000 --- a/test/cinn/adt/test_naive_add.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from cinn.common import DefaultNVGPUTarget, Float -from cinn.frontend import NetBuilder - - -class TestMapExprNaiveAdd(unittest.TestCase): - def setUp(self): - self.inputs = { - "x": np.random.uniform(-1.0, 1.0, [1024, 1024]).astype("float32"), - "y": np.random.uniform(-1.0, 1.0, [1024, 1024]).astype("float32"), - } - - def test_naive_add(self): - builder = NetBuilder("TestMapExprNaiveAdd") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") - y = builder.create_input(Float(32), self.inputs["y"].shape, "y") - - out = builder.elementwise_add(x, y) - prog = builder.build() - target = DefaultNVGPUTarget() - - result = prog.build_and_get_output( - target, - [x, y], - [self.inputs["x"], self.inputs["y"]], - [out], - passes=[], - scope=None, - ) - - np.testing.assert_allclose( - result[0].numpy(target), - self.inputs["x"] + self.inputs["y"], - err_msg="TestMapExprNaiveAdd failed!", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/cinn/adt/test_naive_reduce.py b/test/cinn/adt/test_naive_reduce.py deleted file mode 100644 index 217ef7695f539e..00000000000000 --- a/test/cinn/adt/test_naive_reduce.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from cinn.common import DefaultNVGPUTarget, Float -from cinn.frontend import NetBuilder - - -class TestMapExprNaiveReduce(unittest.TestCase): - def setUp(self): - self.inputs = { - "x": np.random.uniform(-1.0, 1.0, [2, 1024, 1024]).astype("float32") - } - - def test_naive_reduce(self): - builder = NetBuilder("TestMapExprNaiveReduce") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") - - out = builder.reduce_sum(x, [0], False) - prog = builder.build() - target = DefaultNVGPUTarget() - - result = prog.build_and_get_output( - target, [x], [self.inputs["x"]], [out], passes=[], scope=None - ) - - np.testing.assert_allclose( - result[0].numpy(target), - np.sum(self.inputs["x"], axis=0), - err_msg="TestMapExprNaiveReduce failed!", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/cinn/adt/test_reduce_fusion.py b/test/cinn/adt/test_reduce_fusion.py deleted file mode 100644 index 2432e53e21bcc5..00000000000000 --- a/test/cinn/adt/test_reduce_fusion.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from cinn.common import DefaultNVGPUTarget, Float -from cinn.frontend import NetBuilder - - -class TestMapExprReduceFusion(unittest.TestCase): - def setUp(self): - self.inputs = { - "x": np.random.uniform(-1.0, 1.0, [2, 1024, 1024]).astype( - "float32" - ), - "y": np.random.uniform(-1.0, 1.0, [2, 1024, 1024]).astype( - "float32" - ), - } - - def test_reduce_fusion(self): - builder = NetBuilder("TestMapExprReduceFusion") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") - y = builder.create_input(Float(32), self.inputs["y"].shape, "y") - - t = builder.elementwise_add(x, y) - out = builder.reduce_sum(t, [0], False) - prog = builder.build() - - target = DefaultNVGPUTarget() - result = prog.build_and_get_output( - target, - [x, y], - [self.inputs["x"], self.inputs["y"]], - [out], - passes=[], - scope=None, - ) - - np.testing.assert_allclose( - result[0].numpy(target), - np.sum(self.inputs["x"] + self.inputs["y"], axis=0), - err_msg="TestMapExprReduceFusion failed!", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/cinn/adt/test_reduce_schedule_mesh.py b/test/cinn/adt/test_reduce_schedule_mesh.py deleted file mode 100644 index 10bde8159932b5..00000000000000 --- a/test/cinn/adt/test_reduce_schedule_mesh.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from cinn.common import DefaultNVGPUTarget, Float -from cinn.frontend import NetBuilder - - -class TestMapExprReduceScheduleMesh(unittest.TestCase): - def setUp(self): - self.inputs = { - "x": np.random.uniform(-1.0, 1.0, [32, 2048]).astype("float32"), - "y": np.random.uniform(-1.0, 1.0, [32, 2048]).astype("float32"), - } - - def test_schedule_mesh(self): - builder = NetBuilder("TestMapExprReduceScheduleMesh") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") - y = builder.create_input(Float(32), self.inputs["y"].shape, "y") - - t = builder.elementwise_add(x, y) - out = builder.reduce_sum(t, [0], False) - prog = builder.build() - - target = DefaultNVGPUTarget() - - result = prog.build_and_get_output( - target, - [x, y], - [self.inputs["x"], self.inputs["y"]], - [out], - passes=[], - scope=None, - ) - - np.testing.assert_allclose( - result[0].numpy(target), - np.sum(self.inputs["x"] + self.inputs["y"], axis=0), - err_msg="TestMapExprReduceScheduleMesh failed!", - ) - - -if __name__ == "__main__": - unittest.main() From bb259a26db6724a092026ea2e7c48079503d796f Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 10 Nov 2023 01:35:59 +0000 Subject: [PATCH 2/9] Add non-inline translator --- paddle/cinn/adt/inline_translator.h | 37 +-------- paddle/cinn/adt/inline_translator_trait.h | 58 ++++++++++++++ paddle/cinn/adt/naive_op_equation_context.cc | 2 +- paddle/cinn/adt/no_inline_translator.h | 83 ++++++++++++++++++++ paddle/cinn/adt/symbolic_dim_infer_ctx.h | 3 +- paddle/cinn/hlir/framework/pir/group.h | 8 +- 6 files changed, 147 insertions(+), 44 deletions(-) create mode 100644 paddle/cinn/adt/inline_translator_trait.h create mode 100644 paddle/cinn/adt/no_inline_translator.h diff --git a/paddle/cinn/adt/inline_translator.h b/paddle/cinn/adt/inline_translator.h index 9560750bd2882e..d3910791f32b04 100644 --- a/paddle/cinn/adt/inline_translator.h +++ b/paddle/cinn/adt/inline_translator.h @@ -15,47 +15,12 @@ #pragma once #include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/adt/inline_translator_trait.h" #include "paddle/cinn/adt/map_expr.h" #include "paddle/cinn/adt/tree.h" namespace cinn::adt { -template