From f7c775c602c9c106ed35b41e3e09ec1136e1a029 Mon Sep 17 00:00:00 2001 From: Tian Chao Date: Tue, 7 Nov 2023 09:48:18 +0000 Subject: [PATCH 1/3] redefine MapExpr to support conditional kernels --- paddle/cinn/adt/arithmetic.h | 6 ++--- paddle/cinn/adt/ast.h | 22 ++++++++++++++++++ paddle/cinn/adt/logical.h | 2 +- paddle/cinn/adt/map_expr.h | 44 ++++++++++++++++++++++++------------ paddle/cinn/adt/tags.h | 2 ++ 5 files changed, 56 insertions(+), 20 deletions(-) create mode 100644 paddle/cinn/adt/ast.h diff --git a/paddle/cinn/adt/arithmetic.h b/paddle/cinn/adt/arithmetic.h index 5bebcfba657376..904b5c63358164 100644 --- a/paddle/cinn/adt/arithmetic.h +++ b/paddle/cinn/adt/arithmetic.h @@ -19,8 +19,6 @@ namespace cinn::adt { DEFINE_ADT_UNARY(Negative); -template -using Neg = Negative; DEFINE_ADT_UNARY(Reciprocal); DEFINE_ADT_BINARY(Add); DEFINE_ADT_BINARY(Sub); @@ -28,7 +26,7 @@ DEFINE_ADT_BINARY(Mul); DEFINE_ADT_BINARY(Div); DEFINE_ADT_BINARY(Mod); -// Arithmetic T = Neg T +// Arithmetic T = Negative T // | Add T T // | Sub T T // | Mul T T @@ -36,7 +34,7 @@ DEFINE_ADT_BINARY(Mod); // | Mod T T template DEFINE_ADT_UNION(Arithmetic, - Neg, + Negative, Add, Sub, Mul, diff --git a/paddle/cinn/adt/ast.h b/paddle/cinn/adt/ast.h new file mode 100644 index 00000000000000..5423964773830f --- /dev/null +++ b/paddle/cinn/adt/ast.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/adt/arithmetic.h" + +namespace cinn::adt { + +template +struct Let final { + List> var2value; + BodyT body; +}; + +template +struct If final { + ConditionT condition; + TrueValueT true_value; + std::optional false_value; +}; + +} diff --git a/paddle/cinn/adt/logical.h b/paddle/cinn/adt/logical.h index 2ab13f0214874b..c65a7b1380acd8 100644 --- a/paddle/cinn/adt/logical.h +++ b/paddle/cinn/adt/logical.h @@ -42,6 +42,6 @@ template DEFINE_ADT_UNION(LogicalOp, And, Or, Not); template -using Logical = Tree>; +using Logical = Tree>; } // namespace cinn::adt diff --git a/paddle/cinn/adt/map_expr.h b/paddle/cinn/adt/map_expr.h index de7ebe39acbdb6..37a4c091e3b207 100644 --- a/paddle/cinn/adt/map_expr.h +++ b/paddle/cinn/adt/map_expr.h @@ -19,6 +19,7 @@ #include "paddle/cinn/adt/adapter_dynamic_tensor.h" #include "paddle/cinn/adt/adapter_tensor.h" #include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/adt/ast.h" #include "paddle/cinn/adt/arithmetic.h" #include "paddle/cinn/adt/equation_value.h" #include "paddle/cinn/adt/logical.h" @@ -181,29 +182,42 @@ class AnchoredMapStmt final : public Tuple, } }; -DEFINE_ADT_UNION(GenericDim, SymbolicDim, std::int64_t); -using KernelCondition = Logical>; +using CppVar = tVar; -template -class ConditionalAnchoredMapStmt - : public Tuple, tTrue, tFalse> { - public: - using Tuple, tTrue, tFalse>::Tuple; -}; - -using KernelBody = Tree; -// Kernel = (KernelBody, In [Tensor], Out [Tensor]) +// Kernel = (AnchoredMapStmt, In [Tensor], Out [Tensor], [SymbolicDim]) class Kernel final : public Tuple, tIn>, - tOut>> { + tOut>, + List>> { public: - using Tuple, tIn>, tOut>>:: + using Tuple, tIn>, tOut>, List>:: Tuple; }; -// MapExpr = Kernel; -using MapExpr = Kernel; +template +struct ConditionalEntries { + List, T, T>> conditional_entries; +}; + +struct ShapeInferExpr final { + List> output_shape_expr; + ConditionalEntries temp_storage_expr; +}; + +struct GetTensorShapeDim final { + Tensor tensor; + std::int64_t aixs; + DimExpr symbolic_dim_expr; +}; + +template +using WithRuntimeTensorShapeDim = Let; + +struct MapExpr final { + WithRuntimeTensorShapeDim shape_infer_expr; + WithRuntimeTensorShapeDim> host_kernel_expr; +}; } // namespace adt } // namespace cinn diff --git a/paddle/cinn/adt/tags.h b/paddle/cinn/adt/tags.h index 641d7cad316b2d..a6b707075ea392 100644 --- a/paddle/cinn/adt/tags.h +++ b/paddle/cinn/adt/tags.h @@ -39,4 +39,6 @@ DEFINE_ADT_TAG(tHasNoConflictValue); DEFINE_ADT_TAG(tReduceInit); DEFINE_ADT_TAG(tReduceAcc); +DEFINE_ADT_TAG(tVar); + } // namespace cinn::adt From 02689ac9980af0e7f2e350bfaee92d030a7b1a14 Mon Sep 17 00:00:00 2001 From: Tian Chao Date: Tue, 7 Nov 2023 09:49:40 +0000 Subject: [PATCH 2/3] fix typo --- paddle/cinn/adt/logical.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/adt/logical.h b/paddle/cinn/adt/logical.h index c65a7b1380acd8..5685f639c3ed97 100644 --- a/paddle/cinn/adt/logical.h +++ b/paddle/cinn/adt/logical.h @@ -42,6 +42,6 @@ template DEFINE_ADT_UNION(LogicalOp, And, Or, Not); template -using Logical = Tree>; +using Logical = Tree>>; } // namespace cinn::adt From 42c96a156c6050f45044e91eccc1ff18b35d8cf9 Mon Sep 17 00:00:00 2001 From: Tian Chao Date: Tue, 7 Nov 2023 09:50:14 +0000 Subject: [PATCH 3/3] fix typo --- paddle/cinn/adt/logical.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/adt/logical.h b/paddle/cinn/adt/logical.h index 5685f639c3ed97..2ab13f0214874b 100644 --- a/paddle/cinn/adt/logical.h +++ b/paddle/cinn/adt/logical.h @@ -42,6 +42,6 @@ template DEFINE_ADT_UNION(LogicalOp, And, Or, Not); template -using Logical = Tree>>; +using Logical = Tree>; } // namespace cinn::adt