diff --git a/paddle/cinn/adt/CMakeLists.txt b/paddle/cinn/adt/CMakeLists.txt
index 0997562ca548b1..64ab23cd9d638c 100644
--- a/paddle/cinn/adt/CMakeLists.txt
+++ b/paddle/cinn/adt/CMakeLists.txt
@@ -1,45 +1,56 @@
-core_gather_headers()
+if(NOT CINN_ONLY)
+  core_gather_headers()
 
-gather_srcs(
-  cinnapi_src
-  SRCS
-  anchor_sd_equation_context.cc
-  equation_function.cc
-  equation_solver.cc
-  equation_value.cc
-  generate_map_expr.cc
-  get_sub_reshape_dim_ranges.cc
-  graph_symbolic_dim_infer_ctx.cc
-  igroup.cc
-  index_expr_infer_context.cc
-  kgroup.cc
-  m_ir.cc
-  naive_bidirection_equation_generator.cc
-  naive_op_equation_context.cc
-  partition_op_stmts.cc
-  print_equations.cc
-  print_map_expr.cc
-  print_schedule_descriptor.cc
-  print_schedule_dim.cc
-  print_schedule_mesh.cc
-  print_value.cc
-  schedule_descriptor.cc
-  schedule_dim.cc
-  schedule_mesh.cc
-  dim_expr.cc
-  dim_expr_test.cc
-  dim_expr_simplifier.cc
-  symbolic_dim_infer_util.cc
-  simplify_value.cc
-  write_broadcast_disabled_bidirection_equation_generator.cc
-  print_dim_expr.cc)
+  gather_srcs(
+    cinnapi_src
+    SRCS
+    adapter_tensor.cc
+    anchor_sd_equation_context.cc
+    equation_function.cc
+    equation_solver.cc
+    equation_value.cc
+    generate_map_expr.cc
+    get_sub_reshape_dim_ranges.cc
+    graph_symbolic_dim_infer_ctx.cc
+    igroup.cc
+    index_expr_infer_context.cc
+    kgroup.cc
+    m_ir.cc
+    naive_bidirection_equation_generator.cc
+    naive_op_equation_context.cc
+    partition_op_stmts.cc
+    print_equations.cc
+    print_map_expr.cc
+    print_schedule_descriptor.cc
+    print_schedule_dim.cc
+    print_schedule_mesh.cc
+    print_value.cc
+    schedule_descriptor.cc
+    schedule_dim.cc
+    schedule_mesh.cc
+    dim_expr.cc
+    dim_expr_simplifier.cc
+    simplify_value.cc
+    write_broadcast_disabled_bidirection_equation_generator.cc
+    print_dim_expr.cc)
 
-cinn_cc_test(equation_value_match_trait_test SRCS
-             equation_value_match_trait_test.cc DEPS gtest glog)
+  cinn_cc_test(equation_value_match_trait_test SRCS
+               equation_value_match_trait_test.cc DEPS gtest glog)
 
-cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog)
+  cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog)
 
-cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS gtest
-             glog)
+  cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS gtest
+               glog)
 
-message(STATUS "ADT srcs: ${cinnapi_src}")
+  cinn_cc_test(
+    dim_expr_test
+    SRCS
+    dim_expr_test.cc
+    DEPS
+    gtest
+    glog
+    cinncore)
+
+  message(STATUS "ADT srcs: ${cinnapi_src}")
+
+endif()
diff --git a/paddle/cinn/adt/adapter_dynamic_tensor.h b/paddle/cinn/adt/adapter_dynamic_tensor.h
index b5408a95401e50..a36c64b19fc449 100644
--- a/paddle/cinn/adt/adapter_dynamic_tensor.h
+++ b/paddle/cinn/adt/adapter_dynamic_tensor.h
@@ -17,37 +17,30 @@
 
 #include "paddle/cinn/adt/adt.h"
 #include "paddle/cinn/adt/symbolic_dim.h"
-#include "paddle/cinn/hlir/framework/graph.h"
-#include "paddle/cinn/hlir/framework/node.h"
+#include "paddle/cinn/hlir/framework/pir/group.h"
 
 namespace cinn::adt::adapter {
 
 struct DynamicTensor final {
-  const hlir::framework::NodeData* node_data;
-  const hlir::framework::Graph* graph;
+  ::pir::Value node_data;
+  const hlir::framework::pir::Group* group;
 
   bool operator==(const DynamicTensor& other) const {
-    return this->node_data == other.node_data && this->graph == other.graph;
+    return this->node_data == other.node_data;
  }
 
   std::size_t GetRank() const {
-    const auto& shape_dict =
-        graph->GetAttrs<absl::flat_hash_map<std::string, std::vector<int32_t>>>(
-            "infershape");
-    CHECK(shape_dict.count(node_data->id()))
-        << "Can't find " << node_data->id() << " 's shape!";
-    return shape_dict.at(node_data->id()).size();
+    return cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)
+        .size();
   }
 
   const std::vector<std::optional<DimExpr>>& GetShape() const {
-    return graph->graph_ctx()->GetTensorDimExprs(node_data);
+    return group->graph_symbolic_dim_infer_ctx()->GetTensorDimExprs(node_data);
   }
 };
 
 inline std::size_t GetHashValueImpl(const DynamicTensor& tensor) {
-  return hash_combine(
-      std::hash<const hlir::framework::NodeData*>()(tensor.node_data),
-      std::hash<const hlir::framework::Graph*>()(tensor.graph));
+  return std::hash<::pir::Value>()(tensor.node_data);
 }
 
 }  // namespace cinn::adt::adapter
diff --git a/paddle/cinn/adt/adapter_tensor.cc b/paddle/cinn/adt/adapter_tensor.cc
new file mode 100644
index 00000000000000..464c45780dbecd
--- /dev/null
+++ b/paddle/cinn/adt/adapter_tensor.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/adt/adapter_tensor.h"
+#include "glog/logging.h"
+#include "paddle/cinn/hlir/framework/pir/utils.h"
+
+namespace cinn::adt::adapter {
+
+std::size_t Tensor::GetRank() const {
+  return cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)
+      .size();
+}
+
+std::vector<int32_t> Tensor::GetShape() const {
+  std::vector<int32_t> ret{};
+  for (int dim_size :
+       cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) {
+    ret.emplace_back(dim_size);
+  }
+  return ret;
+}
+
+std::size_t Tensor::GetNumel() const {
+  std::size_t ret = 1;
+  for (int dim_size :
+       cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) {
+    ret = ret * dim_size;
+  }
+  return ret;
+}
+
+}  // namespace cinn::adt::adapter
diff --git a/paddle/cinn/adt/adapter_tensor.h b/paddle/cinn/adt/adapter_tensor.h
index 2a6cc941afb89e..dbd2c2dcecfdbb 100644
--- a/paddle/cinn/adt/adapter_tensor.h
+++ b/paddle/cinn/adt/adapter_tensor.h
@@ -13,59 +13,28 @@
 // limitations under the License.
 
 #pragma once
 
-#include "glog/logging.h"
 #include "paddle/cinn/adt/adt.h"
-#include "paddle/cinn/hlir/framework/graph.h"
-#include "paddle/cinn/hlir/framework/node.h"
+#include "paddle/pir/core/value.h"
 
 namespace cinn::adt::adapter {
 
 struct Tensor final {
-  const hlir::framework::NodeData* node_data;
-  const hlir::framework::Graph* graph;
+  ::pir::Value node_data;
 
   bool operator==(const Tensor& other) const {
-    return this->node_data == other.node_data && this->graph == other.graph;
+    return this->node_data == other.node_data;
   }
 
-  std::size_t GetRank() const {
-    const auto& shape_dict =
-        graph->GetAttrs<absl::flat_hash_map<std::string, std::vector<int32_t>>>(
-            "infershape");
-    CHECK(shape_dict.count(node_data->id()))
-        << "Can't find " << node_data->id() << " 's shape!";
-    return shape_dict.at(node_data->id()).size();
-  }
+  std::size_t GetRank() const;
 
-  const std::vector<int32_t>& GetShape() const {
-    const auto& shape_dict =
-        graph->GetAttrs<absl::flat_hash_map<std::string, std::vector<int32_t>>>(
-            "infershape");
-    CHECK(shape_dict.count(node_data->id()))
-        << "Can't find " << node_data->id() << " 's shape!";
-    return shape_dict.at(node_data->id());
-  }
+  std::vector<int32_t> GetShape() const;
 
-  std::size_t GetNumel() const {
-    const auto& shape_dict =
-        graph->GetAttrs<absl::flat_hash_map<std::string, std::vector<int32_t>>>(
-            "infershape");
-    CHECK(shape_dict.count(node_data->id()))
-        << "Can't find " << node_data->id() << " 's shape!";
-    std::vector<int32_t> shape = shape_dict.at(node_data->id());
-    std::size_t ret = 1;
-    for (int32_t dim_size : shape) {
-      ret = ret * dim_size;
-    }
-    return ret;
-  }
+  std::size_t GetNumel() const;
 };
 
 inline std::size_t GetHashValueImpl(const Tensor& tensor) {
-  return hash_combine(
-      std::hash<const hlir::framework::NodeData*>()(tensor.node_data),
-      std::hash<const hlir::framework::Graph*>()(tensor.graph));
+  return std::hash<::pir::Value>()(tensor.node_data);
 }
 
 }  // namespace cinn::adt::adapter
diff --git a/paddle/cinn/adt/equation_value.h b/paddle/cinn/adt/equation_value.h
index a876ffef1bf6d4..1d0c2a134423a5 100644
--- a/paddle/cinn/adt/equation_value.h
+++ b/paddle/cinn/adt/equation_value.h
@@ -20,11 +20,6 @@
 #include "paddle/cinn/adt/equation.h"
 #include "paddle/cinn/adt/match.h"
 
-namespace cinn::hlir::framework {
-class Node;
-class NodeData;
-}  // namespace cinn::hlir::framework
-
 namespace cinn::adt {
 
 DEFINE_ADT_TAG(tPointer);
diff --git a/paddle/cinn/adt/generate_map_expr.cc b/paddle/cinn/adt/generate_map_expr.cc
index 5782902c775f35..124786fb2935b3 100644
--- a/paddle/cinn/adt/generate_map_expr.cc
+++ b/paddle/cinn/adt/generate_map_expr.cc
@@ -28,7 +28,12 @@
 #include "paddle/cinn/adt/schedule_descriptor.h"
 #include "paddle/cinn/adt/symbolic_dim_infer_util.h"
 #include "paddle/cinn/adt/tree.h"
+#include "paddle/cinn/hlir/framework/pir/group.h"
+#include "paddle/cinn/hlir/framework/pir/utils.h"
 #include "paddle/cinn/runtime/flags.h"
+#include "paddle/pir/core/operation.h"
+#include "paddle/pir/core/value.h"
+#include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h"
 
 #include "glog/logging.h"
 
@@ -87,25 +92,24 @@
 using LoopDescriptor4IterVarT = std::function<LoopDescriptor(const Iterator&)>;
 using AnchorTensor = Variable;
 using FakeOpPlaceHolders = List<FakeOpPlaceHolder>;
 
-Op MakeOp(const hlir::framework::Node* op) { return {op}; }
+Op MakeOp(const ::pir::Operation* op) { return {op}; }
 
 template <typename DoEachT>
-void VisitEachInputTensor(const hlir::framework::Node* op,
-                          const DoEachT& DoEach) {
-  for (const auto& graph_edge : op->inlinks_in_order()) {
-    DoEach(graph_edge->source()->safe_as<hlir::framework::NodeData>());
+void VisitEachInputTensor(const ::pir::Operation* op, const DoEachT& DoEach) {
+  for (std::size_t i = 0; i < op->num_operands(); ++i) {
+    DoEach(op->operand_source(i));
   }
 }
 
-List<Arg> MakeOpStmtInputList(const hlir::framework::Node* op,
-                              const hlir::framework::Graph* graph) {
+List<Arg> MakeOpStmtInputList(const ::pir::Operation* op,
+                              const hlir::framework::pir::Group* group) {
   List<Arg> ret{};
 
-  VisitEachInputTensor(op, [&](const auto* tensor) {
+  VisitEachInputTensor(op, [&](const ::pir::Value& tensor) {
     if (FLAGS_cinn_map_expr_enable_dynamic_shape) {
-      ret->emplace_back(adapter::DynamicTensor{tensor, graph});
+      ret->emplace_back(adapter::DynamicTensor{tensor, group});
     } else {
-      ret->emplace_back(adapter::Tensor{tensor, graph});
+      ret->emplace_back(adapter::Tensor{tensor});
     }
   });
 
@@ -113,22 +117,21 @@ List<Arg> MakeOpStmtInputList(const hlir::framework::Node* op,
 }
 
 template <typename DoEachT>
-void VisitEachOutputTensor(const hlir::framework::Node* op,
-                           const DoEachT& DoEach) {
-  for (const auto& graph_edge : op->outlinks_in_order()) {
-    DoEach(graph_edge->sink()->safe_as<hlir::framework::NodeData>());
+void VisitEachOutputTensor(const ::pir::Operation* op, const DoEachT& DoEach) {
+  for (std::size_t i = 0; i < op->num_results(); ++i) {
+    DoEach(const_cast<::pir::Operation*>(op)->result(i));
   }
 }
 
-List<Arg> MakeOpStmtOutputList(const hlir::framework::Node* op,
-                               const hlir::framework::Graph* graph) {
+List<Arg> MakeOpStmtOutputList(const ::pir::Operation* op,
+                               const hlir::framework::pir::Group* group) {
   List<Arg> ret{};
 
-  VisitEachOutputTensor(op, [&](const auto* tensor) {
+  VisitEachOutputTensor(op, [&](const ::pir::Value& tensor) {
     if (FLAGS_cinn_map_expr_enable_dynamic_shape) {
-      ret->emplace_back(adapter::DynamicTensor{tensor, graph});
+      ret->emplace_back(adapter::DynamicTensor{tensor, group});
    } else {
-      ret->emplace_back(adapter::Tensor{tensor, graph});
+      ret->emplace_back(adapter::Tensor{tensor});
    }
  });
 
@@ -136,38 +139,30 @@ List<Arg> MakeOpStmtOutputList(const hlir::framework::Node* op,
 }
 
 template <typename DoEachT>
-void VisitEachOpStmt(
-    const std::shared_ptr<hlir::framework::Graph::Group>& group,
-    const DoEachT& DoEach) {
-  // Note
-  for (const auto* op : group->CollectNodes()) {
+void VisitEachOpStmt(const std::shared_ptr<hlir::framework::pir::Group>& group,
+                     const DoEachT& DoEach) {
+  for (const auto* op : group->CollectOps()) {
     DoEach(OpStmt{MakeOp(op),
-                  MakeOpStmtInputList(op, group->graph_),
-                  MakeOpStmtOutputList(op, group->graph_)});
+                  MakeOpStmtInputList(op, group.get()),
+                  MakeOpStmtOutputList(op, group.get())});
  }
 }
 
-hlir::framework::OpPatternKind GetOpPatternKind(
-    const hlir::framework::Node* node) {
-  static const hlir::framework::OpValueType<hlir::framework::OpPatternKind>&
-      op_pattern_dict =
-          hlir::framework::Operator::GetAttrs<hlir::framework::OpPatternKind>(
-              "OpPattern");
-  auto kind = op_pattern_dict[node->op()];
-  return kind;
+hlir::framework::OpPatternKind GetOpPatternKind(const ::pir::Operation* node) {
+  return hlir::framework::pir::CompatibleInfo::OpKind(*node);
 }
 
 bool CollectRewritedReductionOpStmts(const OpStmt& op_stmt, List<OpStmt>* ret) {
   const auto& [op, inputs, outputs] = op_stmt.tuple();
-  CHECK(op.Has<const hlir::framework::Node*>());
-  if (GetOpPatternKind(op.Get<const hlir::framework::Node*>()) ==
+  CHECK(op.Has<const ::pir::Operation*>());
+  if (GetOpPatternKind(op.Get<const ::pir::Operation*>()) ==
       hlir::framework::OpPatternKind::kReduction) {
-    tReduceInit<const hlir::framework::Node*> init_op{
-        op.Get<const hlir::framework::Node*>()};
+    tReduceInit<const ::pir::Operation*> init_op{
+        op.Get<const ::pir::Operation*>()};
     (*ret)->emplace_back(OpStmt{init_op, List<Arg>{}, outputs});
 
-    tReduceAcc<const hlir::framework::Node*> acc_op{
-        op.Get<const hlir::framework::Node*>()};
+    tReduceAcc<const ::pir::Operation*> acc_op{
+        op.Get<const ::pir::Operation*>()};
     (*ret)->emplace_back(OpStmt{acc_op, inputs, outputs});
     return true;
   } else {
@@ -183,7 +178,7 @@ void CollectRewritedOpStmts(const OpStmt& op_stmt, List<OpStmt>* ret) {
 }
 
 List<OpStmt> MakeOpStmts(
-    const std::shared_ptr<hlir::framework::Graph::Group>& group) {
+    const std::shared_ptr<hlir::framework::pir::Group>& group) {
   List<OpStmt> ret{};
 
   VisitEachOpStmt(group, [&](const auto& op_stmt) {
@@ -212,15 +207,14 @@ std::shared_ptr<IGroup> MakeIGroup(const AnchorGroup& igroup_spec) {
   std::shared_ptr<DirectionEquationGenerator> direction_equation_generator{
      new NaiveBidirectionEquationGenerator{igroup_spec.op_stmts,
                                            igroup_spec.EquationCtx4OpStmt}};
-  CheckEquationSolvable(
-      igroup_spec, direction_equation_generator);
+  CheckEquationSolvable(igroup_spec, direction_equation_generator);
   return std::make_shared<IGroup>(igroup_spec.op_stmts,
                                   igroup_spec.anchor_index,
                                   igroup_spec.EquationCtx4OpStmt);
 }
 
 std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
-    const std::shared_ptr<hlir::framework::Graph::Group>& group) {
+    const std::shared_ptr<hlir::framework::pir::Group>& group) {
   std::vector<std::shared_ptr<IGroup>> ret{};
 
   List<OpStmt> op_stmts = MakeOpStmts(group);
@@ -234,7 +228,7 @@ std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
 }
 
 std::shared_ptr<KGroup> GenerateKGroups(
-    const std::shared_ptr<hlir::framework::Graph::Group>& group,
+    const std::shared_ptr<hlir::framework::pir::Group>& group,
     const std::vector<std::shared_ptr<IGroup>>& igroups) {
   CHECK_EQ(igroups.size(), 1);
   return std::make_shared<KGroup>(group, igroups);
@@ -274,8 +268,7 @@ std::shared_ptr<IndexExprInferContext> SolveEquationsThenReturnCtx(
   GraphView merged_view = igroup_view.Merge(sd_equation_graph_view);
 
   const auto& init_var2value = MakeSdIterator2Iterator(*igroup);
-  auto ctx = std::make_shared<IndexExprInferContext>(
-      init_var2value);
+  auto ctx = std::make_shared<IndexExprInferContext>(init_var2value);
 
   std::vector<Variable> starts{};
   for (const auto& loop_iterator : *igroup->loop_iterators()) {
@@ -350,36 +343,34 @@ Tensor GetAnchorTensor(const std::shared_ptr<IGroup>& igroup) {
 }
 
 template <typename DoEachT>
-void VisitInputTensor(const hlir::framework::Graph::Group& group,
+void VisitInputTensor(const hlir::framework::pir::Group& group,
                       const DoEachT& DoEach) {
-  for (const auto* node_data : group.GetInputNodeDatas()) {
-    DoEach(node_data, group.graph_);
+  for (const ::pir::Value& node_data : group.GetInputOpValues()) {
+    DoEach(node_data);
  }
 }
 
 template <typename DoEachT>
-void VisitOutputTensor(const hlir::framework::Graph::Group& group,
+void VisitOutputTensor(const hlir::framework::pir::Group& group,
                        const DoEachT& DoEach) {
-  for (const auto& node_data : group.GetOutputNodeDatas()) {
-    DoEach(node_data, group.graph_);
+  for (const ::pir::Value& node_data : group.GetOutputOpValues()) {
+    DoEach(node_data);
  }
 }
 
 List<Tensor> MakeInputTensors(const std::shared_ptr<KGroup>& kgroup) {
   List<Tensor> ret{};
-  VisitInputTensor(*kgroup->cinn_group(),
-                   [&](const auto* node_data, const auto* graph) {
-                     ret->emplace_back(adapter::Tensor{node_data, graph});
-                   });
+  VisitInputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) {
+    ret->emplace_back(adapter::Tensor{node_data});
+  });
 
   return ret;
 }
 
 List<Tensor> MakeOutputTensors(const std::shared_ptr<KGroup>& kgroup) {
   List<Tensor> ret{};
-  VisitOutputTensor(*kgroup->cinn_group(),
-                    [&](const auto* node_data, const auto* graph) {
-                      ret->emplace_back(adapter::Tensor{node_data, graph});
-                    });
+  VisitOutputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) {
+    ret->emplace_back(adapter::Tensor{node_data});
+  });
 
   return ret;
 }
@@ -444,7 +435,7 @@ MapExpr GenerateMapExpr(const std::shared_ptr<KGroup>& kgroup) {
 }  // namespace
 
 MapExpr GenerateMapExpr(
-    const std::shared_ptr<hlir::framework::Graph::Group>& group) {
+    const std::shared_ptr<hlir::framework::pir::Group>& group) {
   const auto& igroups = GenerateIGroups(group);
 
   const auto& kgroup = GenerateKGroups(group, igroups);
@@ -453,16 +444,35 @@ MapExpr GenerateMapExpr(
 }
 
 void TryGenerateMapExprFromGraph(
-    const std::shared_ptr<cinn::hlir::framework::Graph>& graph) {
+    const hlir::framework::pir::GroupList& groups) {
   if (!FLAGS_cinn_enable_map_expr) {
     return;
   }
-  graph->set_graph_ctx(adt::InferSymbolicDim(graph.get()));
-  for (const auto& fusion_group : graph->fusion_groups) {
+  for (const auto& fusion_group : groups) {
+    // ADT_TODO : Fake pointer here, remove this later
+    ::pir::SymbolicDimMgr* symbolic_dim_mgr;
+    fusion_group->set_graph_symbolic_dim_infer_ctx(
+        std::make_unique<config::GraphSymbolicDimInferCtx>(fusion_group.get(),
+                                                           symbolic_dim_mgr));
     const auto& map_expr = GenerateMapExpr(fusion_group);
     VLOG(1) << ToTxtString(map_expr, fusion_group->group_id);
     fusion_group->set_map_expr_ctx(std::make_shared<MapExprCtx>(map_expr));
   }
 }
 
+void TryGenerateMapExprFromGroup(
+    const std::shared_ptr<cinn::hlir::framework::pir::Group>& fusion_group) {
+  if (!FLAGS_cinn_enable_map_expr) {
+    return;
+  }
+  // ADT_TODO : Fake pointer here, remove this later
+  ::pir::SymbolicDimMgr* symbolic_dim_mgr;
+  fusion_group->set_graph_symbolic_dim_infer_ctx(
+      std::make_unique<config::GraphSymbolicDimInferCtx>(fusion_group.get(),
+                                                         symbolic_dim_mgr));
+  const auto& map_expr = GenerateMapExpr(fusion_group);
+  VLOG(1) << ToTxtString(map_expr, fusion_group->group_id);
+  fusion_group->set_map_expr_ctx(std::make_shared<MapExprCtx>(map_expr));
+}
+
 }  // namespace cinn::adt
diff --git a/paddle/cinn/adt/generate_map_expr.h b/paddle/cinn/adt/generate_map_expr.h
index cd2970559cb6b0..e9235dd625270b 100644
--- a/paddle/cinn/adt/generate_map_expr.h
+++ b/paddle/cinn/adt/generate_map_expr.h
@@ -14,19 +14,25 @@
 
 #pragma once
 
-#include "paddle/cinn/adt/m_ir.h"
+#include <memory>
+
 #include "paddle/cinn/adt/map_expr.h"
-#include "paddle/cinn/hlir/framework/graph.h"
 
-namespace cinn::adt {
+namespace cinn::hlir::framework::pir {
+
+struct Group;
+using GroupList = std::vector<std::shared_ptr<Group>>;
 
-class IGroup;
-class KGroup;
+}  // namespace cinn::hlir::framework::pir
+
+namespace cinn::adt {
 
 MapExpr GenerateMapExpr(
-    const std::shared_ptr<cinn::hlir::framework::Graph::Group>& group);
+    const std::shared_ptr<cinn::hlir::framework::pir::Group>& group);
+
+void TryGenerateMapExprFromGraph(const hlir::framework::pir::GroupList& groups);
 
-void TryGenerateMapExprFromGraph(
-    const std::shared_ptr<cinn::hlir::framework::Graph>& graph);
+void TryGenerateMapExprFromGroup(
+    const std::shared_ptr<cinn::hlir::framework::pir::Group>& fusion_group);
 
 }  // namespace cinn::adt
diff --git a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc
index 83fc92a203c0f6..0aa0e0d08c8ddb 100644
--- a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc
+++ b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc
@@ -16,57 +16,148 @@
 
 #include "paddle/cinn/adt/dim_expr_simplifier.h"
 #include "paddle/cinn/adt/unique_id.h"
-#include "paddle/cinn/common/graph_utils.h"
-#include "paddle/cinn/hlir/framework/graph.h"
-#include "paddle/cinn/hlir/framework/node.h"
+#include "paddle/cinn/adt/arithmetic.h"
+#include "paddle/cinn/adt/logical.h"
+#include "paddle/cinn/adt/adt.h"
+#include "paddle/cinn/hlir/framework/pir/group.h"
+#include "paddle/cinn/hlir/framework/pir/utils.h"
+#include "paddle/pir/core/operation.h"
+#include "paddle/pir/core/value.h"
+#include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h"
 
 namespace cinn::adt::config {
 
 namespace {
 
-const std::vector<int32_t>& GetShape(const hlir::framework::Graph* graph,
-                                     const hlir::framework::NodeData* tensor) {
-  const auto& shape_dict =
-      graph->GetAttrs<absl::flat_hash_map<std::string, std::vector<int32_t>>>(
-          "infershape");
-  CHECK(shape_dict.count(tensor->id()))
-      << "Can't find " << tensor->id() << " 's shape!";
-  return shape_dict.at(tensor->id());
+// clang-format off
+// Dim equations' configuration:
+//
+//   ShapeDialectConstraints = [ShapeDialectConstraint]
+//   ShapeDialectConstraint = Equal ShapeDialectDimExpr ShapeDialectDimExpr
+//
+//   ShapeDialectDimExpr = ShapeDialectAtomicDim
+//                       | Product ShapeDialectAtomicDim
+//
+//   ShapeDialectAtomicDim = int64_t | ShapeDialectSymbolicDim
+//   ShapeDialectSymbolicDim = (::pir::Value, tAxis int)
+//
+//
+// Dim equations' variables:
+//
+//   ShapeDialectSymbolicDim
+//
+// Dim equations' functions:
+//   DimFunction = DimIdentity (tOut ShapeDialectSymbolicDim) (tIn ShapeDialectSymbolicDim)
+//               | DimProduct (tOut ShapeDialectSymbolicDim) [tIn ShapeDialectSymbolicDim]
+//               | DimReciprocal (tOut ShapeDialectSymbolicDim) (tIn ShapeDialectSymbolicDim)
+//
+// Dim equations' solutions:
+//
+//   DimExpr
+// clang-format on
+
+// ShapeDialectSymbolicDim = (::pir::Value, tAxis int)
+struct ShapeDialectSymbolicDim {
+  ::pir::Value tensor;
+  int axis;
+
+  bool operator==(const ShapeDialectSymbolicDim& other) const {
+    return this->tensor == other.tensor && this->axis == other.axis;
+  }
+};
+// ShapeDialectAtomicDim = int64_t | ShapeDialectSymbolicDim
+DEFINE_ADT_UNION(ShapeDialectAtomicDim, std::int64_t, ShapeDialectSymbolicDim);
+// ShapeDialectDimExpr = ShapeDialectAtomicDim
+//                     | Product ShapeDialectAtomicDim
+DEFINE_ADT_UNION(ShapeDialectDimExpr,
+                 ShapeDialectAtomicDim,
+                 Product<ShapeDialectAtomicDim>);
+// ShapeDialectConstraint = Equal ShapeDialectDimExpr ShapeDialectDimExpr
+using ShapeDialectConstraint = Equal<ShapeDialectDimExpr, ShapeDialectDimExpr>;
+// ShapeDialectConstraints = [ShapeDialectConstraint]
+using ShapeDialectConstraints = List<ShapeDialectConstraint>;
+
+template <typename OutT, typename InT>
+struct DimIdentity;
+
+// DimIdentity (tOut ShapeDialectSymbolicDim) (tIn ShapeDialectSymbolicDim)
+template <>
+struct DimIdentity<tOut<ShapeDialectSymbolicDim>, tIn<ShapeDialectSymbolicDim>>
+    : public Tuple<tOut<ShapeDialectSymbolicDim>, tIn<ShapeDialectSymbolicDim>> {
+  using Tuple<tOut<ShapeDialectSymbolicDim>, tIn<ShapeDialectSymbolicDim>>::Tuple;
+};
+
+template <typename OutT, typename InT>
+struct DimProduct;
+
+// DimProduct (tOut ShapeDialectSymbolicDim) [tIn ShapeDialectSymbolicDim]
+template <>
+struct DimProduct<tOut<ShapeDialectSymbolicDim>,
+                  List<tIn<ShapeDialectSymbolicDim>>>
+    : public Tuple<tOut<ShapeDialectSymbolicDim>,
+                   List<tIn<ShapeDialectSymbolicDim>>> {
+  using Tuple<tOut<ShapeDialectSymbolicDim>,
+              List<tIn<ShapeDialectSymbolicDim>>>::Tuple;
+};
+
+template <typename OutT, typename InT>
+struct DimReciprocal;
+
+// DimReciprocal (tOut ShapeDialectSymbolicDim) (tIn ShapeDialectSymbolicDim)
+template <>
+struct DimReciprocal<tOut<ShapeDialectSymbolicDim>, tIn<ShapeDialectSymbolicDim>>
+    : public Tuple<tOut<ShapeDialectSymbolicDim>, tIn<ShapeDialectSymbolicDim>> {
+  using Tuple<tOut<ShapeDialectSymbolicDim>, tIn<ShapeDialectSymbolicDim>>::Tuple;
+};
+
+// DimFunction = DimIdentity (tOut ShapeDialectSymbolicDim) (tIn ShapeDialectSymbolicDim)
+//             | DimProduct (tOut ShapeDialectSymbolicDim) [tIn ShapeDialectSymbolicDim]
+//             | DimReciprocal (tOut ShapeDialectSymbolicDim) (tIn ShapeDialectSymbolicDim)
+DEFINE_ADT_UNION(DimFunction,
+                 DimIdentity<tOut<ShapeDialectSymbolicDim>,
+                             tIn<ShapeDialectSymbolicDim>>,
+                 DimProduct<tOut<ShapeDialectSymbolicDim>,
+                            List<tIn<ShapeDialectSymbolicDim>>>,
+                 DimReciprocal<tOut<ShapeDialectSymbolicDim>,
+                               tIn<ShapeDialectSymbolicDim>>);
+}
 
 }
 
-std::size_t GetTensorRank(const hlir::framework::Graph* graph,
-                          const hlir::framework::NodeData* tensor) {
-  const auto& shape_dict =
-      graph->GetAttrs<absl::flat_hash_map<std::string, std::vector<int32_t>>>(
-          "infershape");
-  CHECK(shape_dict.count(tensor->id()))
-      << "Can't find " << tensor->id() << " 's shape!";
-  return shape_dict.at(tensor->id()).size();
+namespace std {
+
+template <>
+struct hash<cinn::adt::config::ShapeDialectSymbolicDim> final {
+  std::size_t operator()(
+      const cinn::adt::config::ShapeDialectSymbolicDim& dim) const {
+    return cinn::adt::hash_combine(std::hash<::pir::Value>()(dim.tensor),
+                                   dim.axis);
+  }
+};
+
+}
+
+namespace cinn::adt::config {
+
+namespace {
+
+std::vector<std::int64_t> GetShape(const ::pir::Value& tensor) {
+  std::vector<int> tensor_shape =
+      hlir::framework::pir::CompatibleInfo::ValueShape(tensor);
+  std::vector<std::int64_t> ret{};
+  for (int32_t dim : tensor_shape) {
+    ret.push_back(dim);
+  }
+  return ret;
 }
 
-std::vector<std::uint64_t> GetOpInputRanks(const hlir::framework::Graph* graph,
-                                           const hlir::framework::Node* node) {
+std::size_t GetTensorRank(const ::pir::Value& tensor) {
+  return hlir::framework::pir::CompatibleInfo::ValueShape(tensor).size();
+}
+
+std::vector<std::uint64_t> GetOpInputRanks(const ::pir::Operation* node) {
   std::vector<std::uint64_t> ret{};
-  for (const auto& graph_edge : node->inlinks_in_order()) {
-    const hlir::framework::NodeData* tensor =
-        graph_edge->source()->safe_as<hlir::framework::NodeData>();
-    ret.emplace_back(GetTensorRank(graph, tensor));
+  for (const ::pir::Value& tensor : node->operands_source()) {
+    ret.emplace_back(GetTensorRank(tensor));
  }
  return ret;
 }
 
-std::vector<const hlir::framework::Node*> GetTopoOrderOpNodes(
-    const hlir::framework::Graph* graph) {
-  std::vector<const hlir::framework::Node*> ret{};
-  std::vector<common::GraphNode*> topo_nodes =
-      std::get<0>(graph->topological_order());
-  for (const common::GraphNode* graph_node : topo_nodes) {
-    const hlir::framework::Node* op_node =
-        graph_node->safe_as<hlir::framework::Node>();
-    // if node is NodeData or not op, continue.
-    if (!op_node || op_node->op() == nullptr) {
-      continue;
-    }
+std::vector<const ::pir::Operation*> GetTopoOrderOpNodes(
+    const hlir::framework::pir::Group* group) {
+  std::vector<const ::pir::Operation*> ret{};
+  for (const ::pir::Operation* op_node : group->ops) {
    ret.emplace_back(op_node);
  }
  return ret;
@@ -74,44 +165,34 @@ std::vector<const hlir::framework::Node*> GetTopoOrderOpNodes(
 
 }  // namespace
 
-void GraphSymbolicDimInferCtx::InitOp2TensorRanks() {
-  for (const hlir::framework::Node* op_node : GetTopoOrderOpNodes(graph_)) {
-    const auto& input_ranks = GetOpInputRanks(graph_, op_node);
-    if (op2input_ranks_.find(op_node) == op2input_ranks_.end()) {
-      op2input_ranks_.emplace(op_node, input_ranks);
-    } else {
-      CHECK(input_ranks == op2input_ranks_.at(op_node));
-    }
-  }
-}
-
 namespace {
 
 std::unordered_set<std::string> GetAllOutputNames(
-    const std::vector<const hlir::framework::Node*>& nodes) {
+    const std::vector<const ::pir::Operation*>& nodes) {
   std::unordered_set<std::string> output_names;
-  for (const auto* node : nodes) {
-    for (const auto& link : node->outlinks()) {
-      const auto* out_node = link->sink()->safe_as<hlir::framework::NodeData>();
-      output_names.emplace(out_node->id());
+  for (const auto* op_node : nodes) {
+    for (const ::pir::Value& out_node :
+         const_cast<::pir::Operation*>(op_node)->results()) {
+      output_names.emplace(
+          hlir::framework::pir::CompatibleInfo::ValueName(out_node));
    }
  }
  return output_names;
 }
 
-std::vector<const hlir::framework::NodeData*> GetFeedList(
-    const std::vector<const hlir::framework::Node*>& nodes,
+std::vector<::pir::Value> GetFeedList(
+    const std::vector<const ::pir::Operation*>& op_nodes,
     const std::unordered_set<std::string>& out_names) {
-  std::vector<const hlir::framework::NodeData*> ret{};
+  std::vector<::pir::Value> ret{};
   // if the op's input var name cannot found in out_names, it is the group's
   // feed var
   std::unordered_set<std::string> feed_names;
-  for (const auto* node : nodes) {
-    for (const auto& link : node->inlinks()) {
-      const auto* in_node =
-          link->source()->safe_as<hlir::framework::NodeData>();
-      if (!out_names.count(in_node->id()) && !feed_names.count(in_node->id())) {
-        feed_names.emplace(in_node->id());
+  for (const auto* op_node : op_nodes) {
+    for (const ::pir::Value in_node : op_node->operands_source()) {
+      const auto& node_id =
+          hlir::framework::pir::CompatibleInfo::ValueName(in_node);
+      if (!out_names.count(node_id) && !feed_names.count(node_id)) {
+        feed_names.emplace(node_id);
        ret.emplace_back(in_node);
      }
    }
  }
@@ -120,15 +201,13 @@ std::vector<const hlir::framework::NodeData*> GetFeedList(
 }
 
 std::vector<std::optional<DimExpr>> MakeDimExprForTensor(
-    const hlir::framework::Graph* graph,
-    const hlir::framework::NodeData* node_data) {
+    const ::pir::Value& node_data) {
   std::vector<std::optional<DimExpr>> ret{};
 
-  const std::vector<int32_t>& shape = GetShape(graph, node_data);
+  std::vector<std::int64_t> shape = GetShape(node_data);
   for (std::size_t i = 0; i < shape.size(); ++i) {
     if (i == 0) {
-      static DimExpr temp_elementwise_dim_expr{
-          SymbolicDim{UniqueId::New()}};
+      static DimExpr temp_elementwise_dim_expr{SymbolicDim{UniqueId::New()}};
       ret.emplace_back(temp_elementwise_dim_expr);
     } else {
       ret.emplace_back(DimExpr{shape.at(i)});
     }
@@ -139,68 +218,93 @@ std::vector<std::optional<DimExpr>> MakeDimExprForTensor(
 
 }  // namespace
 
-void GraphSymbolicDimInferCtx::InitGraphInputDimExpr() {
-  std::vector<const hlir::framework::Node*> topo_op_nodes =
-      GetTopoOrderOpNodes(graph_);
-  std::vector<const hlir::framework::NodeData*> feed_list =
-      GetFeedList(topo_op_nodes, GetAllOutputNames(topo_op_nodes));
-  for (const hlir::framework::NodeData* node_data : feed_list) {
-    CHECK(
-        tensor2dim_exprs_
-            .emplace(node_data, MakeDimExprForTensor(graph_, node_data))
-            .second);
+namespace {
+
+template <typename DoEachT>
+void VisitEachTensorPair(const ::pir::Operation* op_node,
+                         const DoEachT& DoEach) {
+  std::vector<::pir::Value> all_tensors{};
+  for (const ::pir::Value tensor : op_node->operands_source()) {
+    all_tensors.emplace_back(tensor);
  }
+  for (const ::pir::Value tensor : op_node->results()) {
+    all_tensors.emplace_back(tensor);
+  }
+  for (std::size_t i = 0; i < all_tensors.size(); ++i) {
+    for (std::size_t j = i + 1; j < all_tensors.size(); ++j) {
+      DoEach(all_tensors.at(i), all_tensors.at(j));
+    }
+  }
+}
+
+void BuildTensorShapeDialectConstraints(
+    const ::pir::Value& lhs,
+    const ::pir::Value& rhs,
+    const ::pir::SymbolicDimMgr* symbolic_dim_mgr,
+    ShapeDialectConstraints* ret) {
+  const auto& lhs_symbolic_dim_ops =
+      symbolic_dim_mgr->CreateSymbolicDimsForRankedValue(lhs);
+  ADT_TODO();
 }
 
-const std::vector<std::uint64_t>& GraphSymbolicDimInferCtx::GetInTensorsRanks(
-    const hlir::framework::Node* node) const {
-  const auto& iter = op2input_ranks_.find(node);
-  CHECK(iter != op2input_ranks_.end());
-  return iter->second;
-}
-
-std::uint64_t GraphSymbolicDimInferCtx::GetNumOutTensors(
-    const hlir::framework::Node* node) const {
-  return node->outlinks_in_order().size();
-}
-
-const DimExpr& GraphSymbolicDimInferCtx::GetInputDimExpr(
-    const hlir::framework::Node* node,
-    std::size_t arg_idx,
-    std::size_t dim_idx) const {
-  const auto& edges = node->inlinks_in_order();
-  CHECK_LT(arg_idx, edges.size());
-  const hlir::framework::NodeData* tensor =
-      edges.at(arg_idx)->source()->safe_as<hlir::framework::NodeData>();
-  const auto& iter = tensor2dim_exprs_.find(tensor);
-  CHECK(iter != tensor2dim_exprs_.end());
-  CHECK_LT(dim_idx, iter->second.size());
-  const auto& opt_dim_expr = iter->second.at(dim_idx);
-  CHECK(opt_dim_expr.has_value());
-  return opt_dim_expr.value();
-}
-
-void GraphSymbolicDimInferCtx::SetOutputDimExpr(
-    const hlir::framework::Node* node,
-    std::size_t arg_idx,
-    std::size_t dim_idx,
-    const DimExpr& value) {
-  const auto& edges = node->outlinks_in_order();
-  CHECK_LT(arg_idx, edges.size());
-  const hlir::framework::NodeData* tensor =
-      edges.at(arg_idx)->sink()->safe_as<hlir::framework::NodeData>();
-  std::size_t rank = GetTensorRank(graph_, tensor);
-  CHECK_LT(dim_idx, rank);
-  auto* opt_symbolic_dims = &tensor2dim_exprs_[tensor];
-  if (dim_idx >= opt_symbolic_dims->size()) {
-    opt_symbolic_dims->resize(dim_idx + 1);
+void BuildOpShapeDialectConstraints(
+    const ::pir::Operation* op_node,
+    const ::pir::SymbolicDimMgr* symbolic_dim_mgr,
+    ShapeDialectConstraints* ret) {
+  VisitEachTensorPair(
+      op_node, [&](const ::pir::Value& lhs, const ::pir::Value& rhs) {
+        BuildTensorShapeDialectConstraints(lhs, rhs, symbolic_dim_mgr, ret);
+      });
+}
+
+ShapeDialectConstraints BuildGraphShapeDialectConstraints(
+    const cinn::hlir::framework::pir::Group* group,
+    const ::pir::SymbolicDimMgr* symbolic_dim_mgr) {
+  ShapeDialectConstraints ret{};
+  for (const ::pir::Operation* op_node : group->ops) {
+    BuildOpShapeDialectConstraints(op_node, symbolic_dim_mgr, &ret);
  }
-  opt_symbolic_dims->at(dim_idx) = SimplifyDimExpr(value);
+  return ret;
+}
+
+// ADT_TODO();
+using GraphView =
+    EquationGraphTopoWalker<ShapeDialectSymbolicDim, const DimFunction*>;
+
+GraphView MakeEquationGraphView(const ShapeDialectConstraints& constraints,
+                                const cinn::hlir::framework::pir::Group* group,
+                                const ::pir::SymbolicDimMgr* symbolic_dim_mgr) {
+  ADT_TODO();
+}
+
+std::unordered_map<ShapeDialectSymbolicDim, DimExpr> MakeEquationStartExpr(
+    const GraphView& graph_view,
+    const cinn::hlir::framework::pir::Group* group,
+    const ::pir::SymbolicDimMgr* symbolic_dim_mgr) {
+  ADT_TODO();
 }
 
-const hlir::framework::AttrMapType& GraphSymbolicDimInferCtx::GetAttributeMap(
-    const hlir::framework::Node* op_node) const {
-  return op_node->attrs.attr_store;
+std::unordered_map<::pir::Value, std::vector<std::optional<DimExpr>>>
+SolveShapeDialectConstraints(
+    const GraphView& graph_view,
+    const std::unordered_map<ShapeDialectSymbolicDim, DimExpr>&
+        equation_start) {
+  ADT_TODO();
+}
+
+}  // namespace
+
+void GraphSymbolicDimInferCtx::InitTensorDimExpr() {
+  ShapeDialectConstraints constraints =
+      BuildGraphShapeDialectConstraints(group_, symbolic_dim_mgr_);
+
+  const auto& graph_view =
+      MakeEquationGraphView(constraints, group_, symbolic_dim_mgr_);
+
+  const auto& equation_start =
+      MakeEquationStartExpr(graph_view, group_, symbolic_dim_mgr_);
+
+  tensor2dim_exprs_ = SolveShapeDialectConstraints(graph_view, equation_start);
 }
 
 }  // namespace cinn::adt::config
diff --git a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h
index 58bda1f4e9b2eb..24d866d35a383c 100644
--- a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h
+++ b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h
@@ -19,13 +19,16 @@
 #include <vector>
 
 #include "paddle/cinn/adt/dim_expr.h"
-#include "paddle/cinn/hlir/framework/node.h"
+#include "paddle/cinn/utils/type_defs.h"
+#include "paddle/pir/core/value.h"
 
+namespace pir {
+class Operation;
+class SymbolicDimMgr;
+}  // namespace pir
 
-namespace cinn::hlir::framework {
-class Graph;
-class NodeData;
-class Node;
-}  // namespace cinn::hlir::framework
+namespace cinn::hlir::framework::pir {
+struct Group;
+}  // namespace cinn::hlir::framework::pir
 
 namespace cinn::adt::config {
 
@@ -34,48 +37,29 @@ class GraphSymbolicDimInferCtx {
   GraphSymbolicDimInferCtx(const GraphSymbolicDimInferCtx&) = delete;
   GraphSymbolicDimInferCtx(GraphSymbolicDimInferCtx&&) = delete;
 
-  explicit GraphSymbolicDimInferCtx(const hlir::framework::Graph* graph)
-      : graph_(graph) {
-    InitOp2TensorRanks();
-    InitGraphInputDimExpr();
+  explicit GraphSymbolicDimInferCtx(
+      const cinn::hlir::framework::pir::Group* group,
+      const ::pir::SymbolicDimMgr* symbolic_dim_mgr)
+      : group_(group), symbolic_dim_mgr_(symbolic_dim_mgr) {
+    InitTensorDimExpr();
   }
 
-  const hlir::framework::Graph* graph() const { return graph_; }
-
-  const std::vector<std::uint64_t>& GetInTensorsRanks(
-      const hlir::framework::Node* node) const;
-
-  std::uint64_t GetNumOutTensors(const hlir::framework::Node* node) const;
-
-  const DimExpr& GetInputDimExpr(const hlir::framework::Node* node,
-                                 std::size_t arg_idx,
-                                 std::size_t dim_idx) const;
+  const cinn::hlir::framework::pir::Group* group() const { return group_; }
 
   const std::vector<std::optional<DimExpr>>& GetTensorDimExprs(
-      const hlir::framework::NodeData* tensor) const {
+      const ::pir::Value tensor) const {
     const auto& iter = tensor2dim_exprs_.find(tensor);
     CHECK(iter != tensor2dim_exprs_.end());
     return iter->second;
   }
 
-  void SetOutputDimExpr(const hlir::framework::Node* node,
-                        std::size_t arg_idx,
-                        std::size_t dim_idx,
-                        const DimExpr& value);
-
-  const hlir::framework::AttrMapType& GetAttributeMap(
-      const hlir::framework::Node* node) const;
-
  private:
-  void InitOp2TensorRanks();
-  void InitGraphInputDimExpr();
+  void InitTensorDimExpr();
 
-  const hlir::framework::Graph* graph_;
-  std::unordered_map<const hlir::framework::NodeData*,
-                     std::vector<std::optional<DimExpr>>>
+  const cinn::hlir::framework::pir::Group* group_;
+  const ::pir::SymbolicDimMgr* symbolic_dim_mgr_;
+  std::unordered_map<::pir::Value, std::vector<std::optional<DimExpr>>>
       tensor2dim_exprs_;
-  std::unordered_map<const hlir::framework::Node*, std::vector<std::uint64_t>>
-      op2input_ranks_;
 };
 
 }  // namespace cinn::adt::config
diff --git a/paddle/cinn/adt/inline_translator.h b/paddle/cinn/adt/inline_translator.h
index 9560750bd2882e..d3910791f32b04 100644
--- a/paddle/cinn/adt/inline_translator.h
+++ b/paddle/cinn/adt/inline_translator.h
@@ -15,47 +15,12 @@
 #pragma once
 
 #include "paddle/cinn/adt/adt.h"
+#include "paddle/cinn/adt/inline_translator_trait.h"
 #include "paddle/cinn/adt/map_expr.h"
 #include "paddle/cinn/adt/tree.h"
 
 namespace cinn::adt {
 
-template