diff --git a/paddle/cinn/adt/arithmetic.h b/paddle/cinn/adt/arithmetic.h index 5bebcfba657376..904b5c63358164 100644 --- a/paddle/cinn/adt/arithmetic.h +++ b/paddle/cinn/adt/arithmetic.h @@ -19,8 +19,6 @@ namespace cinn::adt { DEFINE_ADT_UNARY(Negative); -template -using Neg = Negative; DEFINE_ADT_UNARY(Reciprocal); DEFINE_ADT_BINARY(Add); DEFINE_ADT_BINARY(Sub); @@ -28,7 +26,7 @@ DEFINE_ADT_BINARY(Mul); DEFINE_ADT_BINARY(Div); DEFINE_ADT_BINARY(Mod); -// Arithmetic T = Neg T +// Arithmetic T = Negative T // | Add T T // | Sub T T // | Mul T T @@ -36,7 +34,7 @@ DEFINE_ADT_BINARY(Mod); // | Mod T T template DEFINE_ADT_UNION(Arithmetic, - Neg, + Negative, Add, Sub, Mul, diff --git a/paddle/cinn/adt/ast.h b/paddle/cinn/adt/ast.h new file mode 100644 index 00000000000000..5423964773830f --- /dev/null +++ b/paddle/cinn/adt/ast.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/adt/arithmetic.h" + +namespace cinn::adt { + +template +struct Let final { + List> var2value; + BodyT body; +}; + +template +struct If final { + ConditionT condition; + TrueValueT true_value; + std::optional false_value; +}; + +} diff --git a/paddle/cinn/adt/map_expr.h b/paddle/cinn/adt/map_expr.h index de7ebe39acbdb6..37a4c091e3b207 100644 --- a/paddle/cinn/adt/map_expr.h +++ b/paddle/cinn/adt/map_expr.h @@ -19,6 +19,7 @@ #include "paddle/cinn/adt/adapter_dynamic_tensor.h" #include "paddle/cinn/adt/adapter_tensor.h" #include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/adt/ast.h" #include "paddle/cinn/adt/arithmetic.h" #include "paddle/cinn/adt/equation_value.h" #include "paddle/cinn/adt/logical.h" @@ -181,29 +182,42 @@ class AnchoredMapStmt final : public Tuple, } }; -DEFINE_ADT_UNION(GenericDim, SymbolicDim, std::int64_t); -using KernelCondition = Logical>; +using CppVar = tVar; -template -class ConditionalAnchoredMapStmt - : public Tuple, tTrue, tFalse> { - public: - using Tuple, tTrue, tFalse>::Tuple; -}; - -using KernelBody = Tree; -// Kernel = (KernelBody, In [Tensor], Out [Tensor]) +// Kernel = (AnchoredMapStmt, In [Tensor], Out [Tensor], [SymbolicDim]) class Kernel final : public Tuple, tIn>, - tOut>> { + tOut>, + List>> { public: - using Tuple, tIn>, tOut>>:: + using Tuple, tIn>, tOut>, List>:: Tuple; }; -// MapExpr = Kernel; -using MapExpr = Kernel; +template +struct ConditionalEntries { + List, T, T>> conditional_entries; +}; + +struct ShapeInferExpr final { + List> output_shape_expr; + ConditionalEntries temp_storage_expr; +}; + +struct GetTensorShapeDim final { + Tensor tensor; + std::int64_t aixs; + DimExpr symbolic_dim_expr; +}; + +template +using WithRuntimeTensorShapeDim = Let; + +struct MapExpr final { + WithRuntimeTensorShapeDim shape_infer_expr; + WithRuntimeTensorShapeDim> host_kernel_expr; +}; } // namespace adt } // namespace cinn diff --git a/paddle/cinn/adt/tags.h b/paddle/cinn/adt/tags.h index 641d7cad316b2d..a6b707075ea392 100644 --- a/paddle/cinn/adt/tags.h +++ b/paddle/cinn/adt/tags.h @@ -39,4 +39,6 @@ DEFINE_ADT_TAG(tHasNoConflictValue); DEFINE_ADT_TAG(tReduceInit); DEFINE_ADT_TAG(tReduceAcc); +DEFINE_ADT_TAG(tVar); + } // namespace cinn::adt